diff --git a/README.md b/README.md index 095bb351..76cae4ee 100644 --- a/README.md +++ b/README.md @@ -84,9 +84,9 @@ the employed overload of `WithTools` examines the current assembly for classes w `McpTool` attribute as tools.) ```csharp -using ModelContextProtocol; -using ModelContextProtocol.Server; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; +using ModelContextProtocol.Server; using System.ComponentModel; var builder = Host.CreateEmptyApplicationBuilder(settings: null); @@ -109,7 +109,7 @@ the connected client. Similarly, arguments may be injected via dependency inject `IMcpServer` to make sampling requests back to the client in order to summarize content it downloads from the specified url via an `HttpClient` injected via dependency injection. ```csharp -[McpServerTool("SummarizeContentFromUrl"), Description("Summarizes content downloaded from a specific URI")] +[McpServerTool(Name = "SummarizeContentFromUrl"), Description("Summarizes content downloaded from a specific URI")] public static async Task SummarizeDownloadedContent( IMcpServer thisServer, HttpClient httpClient, @@ -122,8 +122,8 @@ public static async Task SummarizeDownloadedContent( [ new(ChatRole.User, "Briefly summarize the following downloaded content:"), new(ChatRole.User, content), - ] - + ]; + ChatOptions options = new() { MaxOutputTokens = 256, @@ -134,13 +134,24 @@ public static async Task SummarizeDownloadedContent( } ``` +Prompts can be exposed in a similar manner, using `[McpServerPrompt]`, e.g. +```csharp +[McpServerPromptType] +public static class MyPrompts +{ + [McpServerPrompt, Description("Creates a prompt to summarize the provided message.")] + public static ChatMessage Summarize([Description("The content to summarize")] string content) => + new(ChatRole.User, $"Please summarize this content into a single sentence: {content}"); +} +``` + More control is also available, with fine-grained control over configuring the server and how it should handle client requests. For example: ```csharp using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; -using Microsoft.Extensions.Logging.Abstractions; +using System.Text.Json; McpServerOptions options = new() { @@ -149,9 +160,8 @@ McpServerOptions options = new() { Tools = new() { - ListToolsHandler = async (request, cancellationToken) => - { - return new ListToolsResult() + ListToolsHandler = (request, cancellationToken) => + Task.FromResult(new ListToolsResult() { Tools = [ @@ -173,10 +183,9 @@ McpServerOptions options = new() """), } ] - }; - }, + }), - CallToolHandler = async (request, cancellationToken) => + CallToolHandler = (request, cancellationToken) => { if (request.Params?.Name == "echo") { @@ -185,10 +194,10 @@ McpServerOptions options = new() throw new McpServerException("Missing required argument 'message'"); } - return new CallToolResponse() + return Task.FromResult(new CallToolResponse() { Content = [new Content() { Text = $"Echo: {message}", Type = "text" }] - }; + }); } throw new McpServerException($"Unknown tool: '{request.Params?.Name}'"); diff --git a/samples/AspNetCoreSseServer/Program.cs b/samples/AspNetCoreSseServer/Program.cs index a3cd9414..b66b2ce4 100644 --- a/samples/AspNetCoreSseServer/Program.cs +++ b/samples/AspNetCoreSseServer/Program.cs @@ -1,4 +1,3 @@ -using ModelContextProtocol; using AspNetCoreSseServer; var builder = WebApplication.CreateBuilder(args); diff --git a/samples/QuickstartWeatherServer/Program.cs b/samples/QuickstartWeatherServer/Program.cs index fbc2b44b..a191cb16 100644 --- a/samples/QuickstartWeatherServer/Program.cs +++ b/samples/QuickstartWeatherServer/Program.cs @@ -1,6 +1,5 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; -using ModelContextProtocol; using System.Net.Http.Headers; var builder = Host.CreateEmptyApplicationBuilder(settings: null); diff --git a/samples/TestServerWithHosting/Program.cs b/samples/TestServerWithHosting/Program.cs index 82b731df..ee009084 100644 --- a/samples/TestServerWithHosting/Program.cs +++ b/samples/TestServerWithHosting/Program.cs @@ -1,4 +1,4 @@ -using ModelContextProtocol; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Hosting; using Serilog; diff --git a/src/ModelContextProtocol/AIContentExtensions.cs b/src/ModelContextProtocol/AIContentExtensions.cs index 0fe00fb4..4a75f1f0 100644 --- a/src/ModelContextProtocol/AIContentExtensions.cs +++ b/src/ModelContextProtocol/AIContentExtensions.cs @@ -25,6 +25,37 @@ public static ChatMessage ToChatMessage(this PromptMessage promptMessage) }; } + /// Creates s from a . + /// The messages to convert. + /// The created . + public static IList ToChatMessages(this GetPromptResult promptResult) + { + Throw.IfNull(promptResult); + + return promptResult.Messages.Select(m => m.ToChatMessage()).ToList(); + } + + /// Gets instances for the specified . + /// The message for which to extract its contents as instances. + /// The converted content. + public static IList ToPromptMessages(this ChatMessage chatMessage) + { + Throw.IfNull(chatMessage); + + Role r = chatMessage.Role == ChatRole.User ? Role.User : Role.Assistant; + + List messages = []; + foreach (var content in chatMessage.Contents) + { + if (content is TextContent or DataContent) + { + messages.Add(new PromptMessage { Role = r, Content = content.ToContent() }); + } + } + + return messages; + } + /// Creates a new from the content of a . /// The to convert. /// The created . diff --git a/src/ModelContextProtocol/Client/McpClient.cs b/src/ModelContextProtocol/Client/McpClient.cs index f2e0fa5f..3b81c7e2 100644 --- a/src/ModelContextProtocol/Client/McpClient.cs +++ b/src/ModelContextProtocol/Client/McpClient.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.Logging; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; diff --git a/src/ModelContextProtocol/Client/McpClientExtensions.cs b/src/ModelContextProtocol/Client/McpClientExtensions.cs index 5227285a..36742b14 100644 --- a/src/ModelContextProtocol/Client/McpClientExtensions.cs +++ b/src/ModelContextProtocol/Client/McpClientExtensions.cs @@ -115,13 +115,12 @@ public static async IAsyncEnumerable EnumerateToolsAsync( /// The client. /// A token to cancel the operation. /// A list of all available prompts. - public static async Task> ListPromptsAsync( + public static async Task> ListPromptsAsync( this IMcpClient client, CancellationToken cancellationToken = default) { Throw.IfNull(client); - List? prompts = null; - + List? prompts = null; string? cursor = null; do { @@ -129,13 +128,10 @@ public static async Task> ListPromptsAsync( CreateRequest(RequestMethods.PromptsList, CreateCursorDictionary(cursor)), cancellationToken).ConfigureAwait(false); - if (prompts is null) - { - prompts = promptResults.Prompts; - } - else + prompts ??= new List(promptResults.Prompts.Count); + foreach (var prompt in promptResults.Prompts) { - prompts.AddRange(promptResults.Prompts); + prompts.Add(new McpClientPrompt(client, prompt)); } cursor = promptResults.NextCursor; @@ -186,7 +182,7 @@ public static async IAsyncEnumerable EnumeratePromptsAsync( /// A token to cancel the operation. /// A task containing the prompt's content and messages. public static Task GetPromptAsync( - this IMcpClient client, string name, Dictionary? arguments = null, CancellationToken cancellationToken = default) + this IMcpClient client, string name, IReadOnlyDictionary? arguments = null, CancellationToken cancellationToken = default) { Throw.IfNull(client); Throw.IfNullOrWhiteSpace(name); @@ -345,7 +341,7 @@ public static Task ReadResourceAsync( Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesRead, new() { ["uri"] = uri }), + CreateRequest(RequestMethods.ResourcesRead, new Dictionary() { ["uri"] = uri }), cancellationToken); } @@ -369,7 +365,7 @@ public static Task GetCompletionAsync(this IMcpClient client, Re } return client.SendRequestAsync( - CreateRequest(RequestMethods.CompletionComplete, new() + CreateRequest(RequestMethods.CompletionComplete, new Dictionary() { ["ref"] = reference, ["argument"] = new Argument { Name = argumentName, Value = argumentValue } @@ -389,7 +385,7 @@ public static Task SubscribeToResourceAsync(this IMcpClient client, string uri, Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesSubscribe, new() { ["uri"] = uri }), + CreateRequest(RequestMethods.ResourcesSubscribe, new Dictionary() { ["uri"] = uri }), cancellationToken); } @@ -405,7 +401,7 @@ public static Task UnsubscribeFromResourceAsync(this IMcpClient client, string u Throw.IfNullOrWhiteSpace(uri); return client.SendRequestAsync( - CreateRequest(RequestMethods.ResourcesUnsubscribe, new() { ["uri"] = uri }), + CreateRequest(RequestMethods.ResourcesUnsubscribe, new Dictionary() { ["uri"] = uri }), cancellationToken); } @@ -560,11 +556,11 @@ public static Task SetLoggingLevel(this IMcpClient client, LoggingLevel level, C Throw.IfNull(client); return client.SendRequestAsync( - CreateRequest(RequestMethods.LoggingSetLevel, new() { ["level"] = level }), + CreateRequest(RequestMethods.LoggingSetLevel, new Dictionary() { ["level"] = level }), cancellationToken); } - private static JsonRpcRequest CreateRequest(string method, Dictionary? parameters) => + private static JsonRpcRequest CreateRequest(string method, IReadOnlyDictionary? parameters) => new() { Method = method, diff --git a/src/ModelContextProtocol/Client/McpClientFactory.cs b/src/ModelContextProtocol/Client/McpClientFactory.cs index f50329f8..97c2371f 100644 --- a/src/ModelContextProtocol/Client/McpClientFactory.cs +++ b/src/ModelContextProtocol/Client/McpClientFactory.cs @@ -1,6 +1,5 @@ using System.Globalization; using System.Runtime.InteropServices; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Utils; diff --git a/src/ModelContextProtocol/Client/McpClientPrompt.cs b/src/ModelContextProtocol/Client/McpClientPrompt.cs new file mode 100644 index 00000000..8deed8eb --- /dev/null +++ b/src/ModelContextProtocol/Client/McpClientPrompt.cs @@ -0,0 +1,41 @@ +using ModelContextProtocol.Protocol.Types; + +namespace ModelContextProtocol.Client; + +/// Provides an invocable prompt. +public sealed class McpClientPrompt +{ + private readonly IMcpClient _client; + + internal McpClientPrompt(IMcpClient client, Prompt prompt) + { + _client = client; + ProtocolPrompt = prompt; + } + + /// Gets the protocol type for this instance. + public Prompt ProtocolPrompt { get; } + + /// + /// Retrieves a specific prompt with optional arguments. + /// + /// Optional arguments for the prompt + /// A token to cancel the operation. + /// A task containing the prompt's content and messages. + public async ValueTask GetAsync( + IEnumerable>? arguments = null, + CancellationToken cancellationToken = default) + { + IReadOnlyDictionary? argDict = + arguments as IReadOnlyDictionary ?? + arguments?.ToDictionary(); + + return await _client.GetPromptAsync(ProtocolPrompt.Name, argDict, cancellationToken).ConfigureAwait(false); + } + + /// Gets the name of the prompt. + public string Name => ProtocolPrompt.Name; + + /// Gets a description of the prompt. + public string? Description => ProtocolPrompt.Description; +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Configuration/DefaultMcpServerBuilder.cs b/src/ModelContextProtocol/Configuration/DefaultMcpServerBuilder.cs index ed686eab..6b46c8e3 100644 --- a/src/ModelContextProtocol/Configuration/DefaultMcpServerBuilder.cs +++ b/src/ModelContextProtocol/Configuration/DefaultMcpServerBuilder.cs @@ -1,12 +1,11 @@ using ModelContextProtocol.Utils; -using Microsoft.Extensions.DependencyInjection; -namespace ModelContextProtocol.Configuration; +namespace Microsoft.Extensions.DependencyInjection; /// /// Default implementation of . /// -internal class DefaultMcpServerBuilder : IMcpServerBuilder +internal sealed class DefaultMcpServerBuilder : IMcpServerBuilder { /// public IServiceCollection Services { get; } diff --git a/src/ModelContextProtocol/Configuration/IMcpServerBuilder.cs b/src/ModelContextProtocol/Configuration/IMcpServerBuilder.cs index 219bd411..5ef3971c 100644 --- a/src/ModelContextProtocol/Configuration/IMcpServerBuilder.cs +++ b/src/ModelContextProtocol/Configuration/IMcpServerBuilder.cs @@ -1,7 +1,6 @@ using ModelContextProtocol.Server; -using Microsoft.Extensions.DependencyInjection; -namespace ModelContextProtocol.Configuration; +namespace Microsoft.Extensions.DependencyInjection; /// /// Builder for configuring instances. diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Handler.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Handler.cs deleted file mode 100644 index 3612b925..00000000 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Handler.cs +++ /dev/null @@ -1,143 +0,0 @@ -using ModelContextProtocol.Configuration; -using ModelContextProtocol.Protocol.Types; -using ModelContextProtocol.Server; -using ModelContextProtocol.Utils; -using Microsoft.Extensions.DependencyInjection; - -namespace ModelContextProtocol; - -/// -/// Extension to configure the MCP server with handlers -/// -public static partial class McpServerBuilderExtensions -{ - /// - /// Sets the handler for list resource templates requests. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.ListResourceTemplatesHandler = handler); - return builder; - } - - /// - /// Sets the handler for list tools requests. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.ListToolsHandler = handler); - return builder; - } - - /// - /// Sets the handler for call tool requests. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.CallToolHandler = handler); - return builder; - } - - /// - /// Sets the handler for list prompts requests. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.ListPromptsHandler = handler); - return builder; - } - - /// - /// Sets the handler for get prompt requests. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.GetPromptHandler = handler); - return builder; - } - - /// - /// Sets the handler for list resources requests. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.ListResourcesHandler = handler); - return builder; - } - - /// - /// Sets the handler for read resources requests. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.ReadResourceHandler = handler); - return builder; - } - - /// - /// Sets the handler for get completion requests. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithGetCompletionHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.GetCompletionHandler = handler); - return builder; - } - - /// - /// Sets the handler for subscribe to resources messages. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.SubscribeToResourcesHandler = handler); - return builder; - } - - /// - /// Sets or sets the handler for subscribe to resources messages. - /// - /// The builder instance. - /// The handler. - public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) - { - Throw.IfNull(builder); - - builder.Services.Configure(s => s.UnsubscribeFromResourcesHandler = handler); - return builder; - } -} diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs deleted file mode 100644 index 2dc0772c..00000000 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Tools.cs +++ /dev/null @@ -1,102 +0,0 @@ -using ModelContextProtocol.Configuration; -using ModelContextProtocol.Server; -using ModelContextProtocol.Utils; -using System.Diagnostics.CodeAnalysis; -using System.Reflection; -using Microsoft.Extensions.DependencyInjection; - -namespace ModelContextProtocol; - -/// -/// Extension to configure the MCP server with tools -/// -public static partial class McpServerBuilderExtensions -{ - private const string WithToolsRequiresUnreferencedCodeMessage = - $"The non-generic {nameof(WithTools)} and {nameof(WithToolsFromAssembly)} methods require dynamic lookup of method metadata" + - $"and may not work in Native AOT. Use the generic {nameof(WithTools)} method instead."; - - /// Adds instances to the service collection backing . - /// The tool type. - /// The builder instance. - /// is . - /// - /// This method discovers all instance and static methods (public and non-public) on the specified - /// type, where the methods are attributed as , and adds an - /// instance for each. For instance methods, an instance will be constructed for each invocation of the tool. - /// - public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers( - DynamicallyAccessedMemberTypes.PublicMethods | - DynamicallyAccessedMemberTypes.NonPublicMethods | - DynamicallyAccessedMemberTypes.PublicConstructors)] TTool>( - this IMcpServerBuilder builder) - { - Throw.IfNull(builder); - - foreach (var toolMethod in typeof(TTool).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) - { - if (toolMethod.GetCustomAttribute() is not null) - { - builder.Services.AddSingleton((Func)(toolMethod.IsStatic ? - services => McpServerTool.Create(toolMethod, options: new() { Services = services }) : - services => McpServerTool.Create(toolMethod, typeof(TTool), new() { Services = services }))); - } - } - - return builder; - } - - /// Adds instances to the service collection backing . - /// The builder instance. - /// Types with marked methods to add as tools to the server. - /// is . - /// is . - /// - /// This method discovers all instance and static methods (public and non-public) on the specified - /// types, where the methods are attributed as , and adds an - /// instance for each. For instance methods, an instance will be constructed for each invocation of the tool. - /// - [RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)] - public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params IEnumerable toolTypes) - { - Throw.IfNull(builder); - Throw.IfNull(toolTypes); - - foreach (var toolType in toolTypes) - { - if (toolType is not null) - { - foreach (var toolMethod in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) - { - if (toolMethod.GetCustomAttribute() is not null) - { - builder.Services.AddSingleton((Func)(toolMethod.IsStatic ? - services => McpServerTool.Create(toolMethod, options: new() { Services = services }) : - services => McpServerTool.Create(toolMethod, toolType, new() { Services = services }))); - } - } - } - } - - return builder; - } - - /// - /// Adds types marked with the attribute from the given assembly as tools to the server. - /// - /// The builder instance. - /// The assembly to load the types from. Null to get the current assembly - /// is . - [RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)] - public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder builder, Assembly? toolAssembly = null) - { - Throw.IfNull(builder); - - toolAssembly ??= Assembly.GetCallingAssembly(); - - return builder.WithTools( - from t in toolAssembly.GetTypes() - where t.GetCustomAttribute() is not null - select t); - } -} diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs deleted file mode 100644 index 938d1e10..00000000 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.Transports.cs +++ /dev/null @@ -1,52 +0,0 @@ -using ModelContextProtocol.Configuration; -using ModelContextProtocol.Hosting; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Utils; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Options; -using ModelContextProtocol.Server; - -namespace ModelContextProtocol; - -/// -/// Extension to configure the MCP server with transports -/// -public static partial class McpServerBuilderExtensions -{ - /// - /// Adds a server transport that uses stdin/stdout for communication. - /// - /// The builder instance. - public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder builder) - { - Throw.IfNull(builder); - - builder.Services.AddSingleton(); - builder.Services.AddHostedService(); - - builder.Services.AddSingleton(services => - { - ITransport serverTransport = services.GetRequiredService(); - IOptions options = services.GetRequiredService>(); - ILoggerFactory? loggerFactory = services.GetService(); - - return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); - }); - - return builder; - } - - /// - /// Adds a server transport that uses SSE via a HttpListener for communication. - /// - /// The builder instance. - public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder) - { - Throw.IfNull(builder); - - builder.Services.AddSingleton(); - builder.Services.AddHostedService(); - return builder; - } -} diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs new file mode 100644 index 00000000..06b111e6 --- /dev/null +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs @@ -0,0 +1,367 @@ +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Hosting; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using ModelContextProtocol.Utils; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace Microsoft.Extensions.DependencyInjection; + +/// +/// Provides methods for configuring MCP servers via dependency injection. +/// +public static partial class McpServerBuilderExtensions +{ + #region WithTools + private const string WithToolsRequiresUnreferencedCodeMessage = + $"The non-generic {nameof(WithTools)} and {nameof(WithToolsFromAssembly)} methods require dynamic lookup of method metadata" + + $"and may not work in Native AOT. Use the generic {nameof(WithTools)} method instead."; + + /// Adds instances to the service collection backing . + /// The tool type. + /// The builder instance. + /// is . + /// + /// This method discovers all instance and static methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each. For instance methods, an instance will be constructed for each invocation of the tool. + /// + public static IMcpServerBuilder WithTools<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods | + DynamicallyAccessedMemberTypes.PublicConstructors)] TToolType>( + this IMcpServerBuilder builder) + { + Throw.IfNull(builder); + + foreach (var toolMethod in typeof(TToolType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (toolMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton((Func)(toolMethod.IsStatic ? + services => McpServerTool.Create(toolMethod, options: new() { Services = services }) : + services => McpServerTool.Create(toolMethod, typeof(TToolType), new() { Services = services }))); + } + } + + return builder; + } + + /// Adds instances to the service collection backing . + /// The builder instance. + /// Types with marked methods to add as tools to the server. + /// is . + /// is . + /// + /// This method discovers all instance and static methods (public and non-public) on the specified + /// types, where the methods are attributed as , and adds an + /// instance for each. For instance methods, an instance will be constructed for each invocation of the tool. + /// + [RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)] + public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, params IEnumerable toolTypes) + { + Throw.IfNull(builder); + Throw.IfNull(toolTypes); + + foreach (var toolType in toolTypes) + { + if (toolType is not null) + { + foreach (var toolMethod in toolType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (toolMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton((Func)(toolMethod.IsStatic ? + services => McpServerTool.Create(toolMethod, options: new() { Services = services }) : + services => McpServerTool.Create(toolMethod, toolType, new() { Services = services }))); + } + } + } + } + + return builder; + } + + /// + /// Adds types marked with the attribute from the given assembly as tools to the server. + /// + /// The builder instance. + /// The assembly to load the types from. Null to get the current assembly + /// is . + [RequiresUnreferencedCode(WithToolsRequiresUnreferencedCodeMessage)] + public static IMcpServerBuilder WithToolsFromAssembly(this IMcpServerBuilder builder, Assembly? toolAssembly = null) + { + Throw.IfNull(builder); + + toolAssembly ??= Assembly.GetCallingAssembly(); + + return builder.WithTools( + from t in toolAssembly.GetTypes() + where t.GetCustomAttribute() is not null + select t); + } + #endregion + + #region WithPrompts + private const string WithPromptsRequiresUnreferencedCodeMessage = + $"The non-generic {nameof(WithPrompts)} and {nameof(WithPromptsFromAssembly)} methods require dynamic lookup of method metadata" + + $"and may not work in Native AOT. Use the generic {nameof(WithPrompts)} method instead."; + + /// Adds instances to the service collection backing . + /// The prompt type. + /// The builder instance. + /// is . + /// + /// This method discovers all instance and static methods (public and non-public) on the specified + /// type, where the methods are attributed as , and adds an + /// instance for each. For instance methods, an instance will be constructed for each invocation of the prompt. + /// + public static IMcpServerBuilder WithPrompts<[DynamicallyAccessedMembers( + DynamicallyAccessedMemberTypes.PublicMethods | + DynamicallyAccessedMemberTypes.NonPublicMethods | + DynamicallyAccessedMemberTypes.PublicConstructors)] TPromptType>( + this IMcpServerBuilder builder) + { + Throw.IfNull(builder); + + foreach (var promptMethod in typeof(TPromptType).GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (promptMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton((Func)(promptMethod.IsStatic ? + services => McpServerPrompt.Create(promptMethod, options: new() { Services = services }) : + services => McpServerPrompt.Create(promptMethod, typeof(TPromptType), new() { Services = services }))); + } + } + + return builder; + } + + /// Adds instances to the service collection backing . + /// The builder instance. + /// Types with marked methods to add as prompts to the server. + /// is . + /// is . + /// + /// This method discovers all instance and static methods (public and non-public) on the specified + /// types, where the methods are attributed as , and adds an + /// instance for each. For instance methods, an instance will be constructed for each invocation of the prompt. + /// + [RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)] + public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, params IEnumerable promptTypes) + { + Throw.IfNull(builder); + Throw.IfNull(promptTypes); + + foreach (var promptType in promptTypes) + { + if (promptType is not null) + { + foreach (var promptMethod in promptType.GetMethods(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Static | BindingFlags.Instance)) + { + if (promptMethod.GetCustomAttribute() is not null) + { + builder.Services.AddSingleton((Func)(promptMethod.IsStatic ? + services => McpServerPrompt.Create(promptMethod, options: new() { Services = services }) : + services => McpServerPrompt.Create(promptMethod, promptType, new() { Services = services }))); + } + } + } + } + + return builder; + } + + /// + /// Adds types marked with the attribute from the given assembly as prompts to the server. + /// + /// The builder instance. + /// The assembly to load the types from. Null to get the current assembly + /// is . + [RequiresUnreferencedCode(WithPromptsRequiresUnreferencedCodeMessage)] + public static IMcpServerBuilder WithPromptsFromAssembly(this IMcpServerBuilder builder, Assembly? promptAssembly = null) + { + Throw.IfNull(builder); + + promptAssembly ??= Assembly.GetCallingAssembly(); + + return builder.WithPrompts( + from t in promptAssembly.GetTypes() + where t.GetCustomAttribute() is not null + select t); + } + #endregion + + #region Handlers + /// + /// Sets the handler for list resource templates requests. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithListResourceTemplatesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.ListResourceTemplatesHandler = handler); + return builder; + } + + /// + /// Sets the handler for list tools requests. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithListToolsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.ListToolsHandler = handler); + return builder; + } + + /// + /// Sets the handler for call tool requests. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithCallToolHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.CallToolHandler = handler); + return builder; + } + + /// + /// Sets the handler for list prompts requests. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithListPromptsHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.ListPromptsHandler = handler); + return builder; + } + + /// + /// Sets the handler for get prompt requests. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithGetPromptHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.GetPromptHandler = handler); + return builder; + } + + /// + /// Sets the handler for list resources requests. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithListResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.ListResourcesHandler = handler); + return builder; + } + + /// + /// Sets the handler for read resources requests. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithReadResourceHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.ReadResourceHandler = handler); + return builder; + } + + /// + /// Sets the handler for get completion requests. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithGetCompletionHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.GetCompletionHandler = handler); + return builder; + } + + /// + /// Sets the handler for subscribe to resources messages. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithSubscribeToResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.SubscribeToResourcesHandler = handler); + return builder; + } + + /// + /// Sets or sets the handler for subscribe to resources messages. + /// + /// The builder instance. + /// The handler. + public static IMcpServerBuilder WithUnsubscribeFromResourcesHandler(this IMcpServerBuilder builder, Func, CancellationToken, Task> handler) + { + Throw.IfNull(builder); + + builder.Services.Configure(s => s.UnsubscribeFromResourcesHandler = handler); + return builder; + } + #endregion + + #region Transports + /// + /// Adds a server transport that uses stdin/stdout for communication. + /// + /// The builder instance. + public static IMcpServerBuilder WithStdioServerTransport(this IMcpServerBuilder builder) + { + Throw.IfNull(builder); + + builder.Services.AddSingleton(); + builder.Services.AddHostedService(); + + builder.Services.AddSingleton(services => + { + ITransport serverTransport = services.GetRequiredService(); + IOptions options = services.GetRequiredService>(); + ILoggerFactory? loggerFactory = services.GetService(); + + return McpServerFactory.Create(serverTransport, options.Value, loggerFactory, services); + }); + + return builder; + } + + /// + /// Adds a server transport that uses SSE via a HttpListener for communication. + /// + /// The builder instance. + public static IMcpServerBuilder WithHttpListenerSseServerTransport(this IMcpServerBuilder builder) + { + Throw.IfNull(builder); + + builder.Services.AddSingleton(); + builder.Services.AddHostedService(); + return builder; + } + #endregion +} diff --git a/src/ModelContextProtocol/Configuration/McpServerConfig.cs b/src/ModelContextProtocol/Configuration/McpServerConfig.cs index 71dcdb42..c8d0a26e 100644 --- a/src/ModelContextProtocol/Configuration/McpServerConfig.cs +++ b/src/ModelContextProtocol/Configuration/McpServerConfig.cs @@ -1,4 +1,4 @@ -namespace ModelContextProtocol.Configuration; +namespace ModelContextProtocol; /// /// Configuration for an MCP server connection. diff --git a/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs b/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs index f1b91af3..2ef0300e 100644 --- a/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs +++ b/src/ModelContextProtocol/Configuration/McpServerOptionsSetup.cs @@ -3,16 +3,18 @@ using Microsoft.Extensions.Options; using ModelContextProtocol.Utils; -namespace ModelContextProtocol.Configuration; +namespace Microsoft.Extensions.DependencyInjection; /// /// Configures the McpServerOptions using addition services from DI. /// /// The server handlers configuration options. /// Tools individually registered. +/// Prompts individually registered. internal sealed class McpServerOptionsSetup( IOptions serverHandlers, - IEnumerable serverTools) : IConfigureOptions + IEnumerable serverTools, + IEnumerable serverPrompts) : IConfigureOptions { /// /// Configures the given McpServerOptions instance by setting server information @@ -39,17 +41,34 @@ public void Configure(McpServerOptions options) // a collection, add to it, otherwise create a new one. We want to maintain the identity // of an existing collection in case someone has provided their own derived type, wants // change notifications, etc. - McpServerToolCollection toolsCollection = options.Capabilities?.Tools?.ToolCollection ?? []; + McpServerPrimitiveCollection toolCollection = options.Capabilities?.Tools?.ToolCollection ?? []; foreach (var tool in serverTools) { - toolsCollection.TryAdd(tool); + toolCollection.TryAdd(tool); } - if (!toolsCollection.IsEmpty) + if (!toolCollection.IsEmpty) { options.Capabilities ??= new(); options.Capabilities.Tools ??= new(); - options.Capabilities.Tools.ToolCollection = toolsCollection; + options.Capabilities.Tools.ToolCollection = toolCollection; + } + + // Collect all of the provided prompts into a prompts collection. If the options already has + // a collection, add to it, otherwise create a new one. We want to maintain the identity + // of an existing collection in case someone has provided their own derived type, wants + // change notifications, etc. + McpServerPrimitiveCollection promptCollection = options.Capabilities?.Prompts?.PromptCollection ?? []; + foreach (var prompt in serverPrompts) + { + promptCollection.TryAdd(prompt); + } + + if (!promptCollection.IsEmpty) + { + options.Capabilities ??= new(); + options.Capabilities.Prompts ??= new(); + options.Capabilities.Prompts.PromptCollection = promptCollection; } // Apply custom server handlers. diff --git a/src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs b/src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtensions.cs similarity index 72% rename from src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs rename to src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtensions.cs index 55bbb4ea..4e772a02 100644 --- a/src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtension.cs +++ b/src/ModelContextProtocol/Configuration/McpServerServiceCollectionExtensions.cs @@ -1,16 +1,12 @@ -using ModelContextProtocol.Configuration; -using ModelContextProtocol.Protocol.Transport; -using ModelContextProtocol.Server; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; +using ModelContextProtocol.Server; using Microsoft.Extensions.Options; -namespace ModelContextProtocol; +namespace Microsoft.Extensions.DependencyInjection; /// /// Extension to host an MCP server /// -public static class McpServerServiceCollectionExtension +public static class McpServerServiceCollectionExtensions { /// /// Adds the MCP server to the service collection with default options. diff --git a/src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs b/src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs index 8fb50c3b..db7cd677 100644 --- a/src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs +++ b/src/ModelContextProtocol/Protocol/Messages/NotificationMethods.cs @@ -13,7 +13,7 @@ public static class NotificationMethods /// /// Sent by the server when the list of prompts changes. /// - public const string PromptsListChanged = "notifications/prompts/list_changed"; + public const string PromptListChangedNotification = "notifications/prompts/list_changed"; /// /// Sent by the server when the list of resources changes. diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs index 9d3e24c5..2b4acac9 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientSessionTransport.cs @@ -1,6 +1,5 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Utils; diff --git a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs index c136e162..b6e76844 100644 --- a/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/SseClientTransport.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.Logging; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Utils; namespace ModelContextProtocol.Protocol.Transport; diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs index 2f3cf420..35c957e5 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientStreamTransport.cs @@ -1,6 +1,5 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Logging; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Utils; diff --git a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs index e1c2ed2d..d2b51b95 100644 --- a/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs +++ b/src/ModelContextProtocol/Protocol/Transport/StdioClientTransport.cs @@ -1,5 +1,4 @@ using Microsoft.Extensions.Logging; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Utils; namespace ModelContextProtocol.Protocol.Transport; diff --git a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs index b3f758ed..f3335778 100644 --- a/src/ModelContextProtocol/Protocol/Types/Capabilities.cs +++ b/src/ModelContextProtocol/Protocol/Types/Capabilities.cs @@ -97,6 +97,18 @@ public class PromptsCapability /// [JsonIgnore] public Func, CancellationToken, Task>? GetPromptHandler { get; set; } + + /// Gets or sets a collection of prompts served by the server. + /// + /// Prompts will specified via augment the and + /// , if provided. ListPrompts requests will output information about every prompt + /// in and then also any tools output by , if it's + /// non-. GetPrompt requests will first check for the prompt + /// being requested, and if the tool is not found in the , any specified + /// will be invoked as a fallback. + /// + [JsonIgnore] + public McpServerPrimitiveCollection? PromptCollection { get; set; } } /// @@ -182,5 +194,5 @@ public class ToolsCapability /// will be invoked as a fallback. /// [JsonIgnore] - public McpServerToolCollection? ToolCollection { get; set; } + public McpServerPrimitiveCollection? ToolCollection { get; set; } } \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs new file mode 100644 index 00000000..cc2bae5d --- /dev/null +++ b/src/ModelContextProtocol/Server/AIFunctionMcpServerPrompt.cs @@ -0,0 +1,256 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Utils; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; +using System.Text.Json; + +namespace ModelContextProtocol.Server; + +/// Provides an that's implemented via an . +internal sealed class AIFunctionMcpServerPrompt : McpServerPrompt +{ + /// Key used temporarily for flowing request context into an AIFunction. + /// This will be replaced with use of AIFunctionArguments.Context. + internal const string RequestContextKey = "__temporary_RequestContext"; + + /// + /// Creates an instance for a method, specified via a instance. + /// + public static new AIFunctionMcpServerPrompt Create( + Delegate method, + McpServerPromptCreateOptions? options) + { + Throw.IfNull(method); + + options = DeriveOptions(method.Method, options); + + return Create(method.Method, method.Target, options); + } + + /// + /// Creates an instance for a method, specified via a instance. + /// + public static new AIFunctionMcpServerPrompt Create( + MethodInfo method, + object? target, + McpServerPromptCreateOptions? options) + { + Throw.IfNull(method); + + // TODO: Once this repo consumes a new build of Microsoft.Extensions.AI containing + // https://github.com/dotnet/extensions/pull/6158, + // https://github.com/dotnet/extensions/pull/6162, and + // https://github.com/dotnet/extensions/pull/6175, switch over to using the real + // AIFunctionFactory, delete the TemporaryXx types, and fix-up the mechanism by + // which the arguments are passed. + + options = DeriveOptions(method, options); + + return Create( + TemporaryAIFunctionFactory.Create(method, target, CreateAIFunctionFactoryOptions(method, options)), + options); + } + + /// + /// Creates an instance for a method, specified via a instance. + /// + public static new AIFunctionMcpServerPrompt Create( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + McpServerPromptCreateOptions? options) + { + Throw.IfNull(method); + + options = DeriveOptions(method, options); + + return Create( + TemporaryAIFunctionFactory.Create(method, targetType, CreateAIFunctionFactoryOptions(method, options)), + options); + } + + private static TemporaryAIFunctionFactoryOptions CreateAIFunctionFactoryOptions( + MethodInfo method, McpServerPromptCreateOptions? options) => + new() + { + Name = options?.Name ?? method.GetCustomAttribute()?.Name, + Description = options?.Description, + MarshalResult = static (result, _, cancellationToken) => Task.FromResult(result), + ConfigureParameterBinding = pi => + { + if (pi.ParameterType == typeof(RequestContext)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => GetRequestContext(args), + }; + } + + if (pi.ParameterType == typeof(IMcpServer)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => GetRequestContext(args)?.Server, + }; + } + + // We assume that if the services used to create the prompt support a particular type, + // so too do the services associated with the server. This is the same basic assumption + // made in ASP.NET. + if (options?.Services is { } services && + services.GetService() is { } ispis && + ispis.IsService(pi.ParameterType)) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + GetRequestContext(args)?.Server?.Services?.GetService(pi.ParameterType) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + + if (pi.GetCustomAttribute() is { } keyedAttr) + { + return new() + { + ExcludeFromSchema = true, + BindParameter = (pi, args) => + (GetRequestContext(args)?.Server?.Services as IKeyedServiceProvider)?.GetKeyedService(pi.ParameterType, keyedAttr.Key) ?? + (pi.HasDefaultValue ? null : + throw new ArgumentException("No service of the requested type was found.")), + }; + } + + return default; + + static RequestContext? GetRequestContext(IReadOnlyDictionary args) + { + if (args.TryGetValue(RequestContextKey, out var orc) && + orc is RequestContext requestContext) + { + return requestContext; + } + + return null; + } + }, + }; + + /// Creates an that wraps the specified . + public static new AIFunctionMcpServerPrompt Create(AIFunction function, McpServerPromptCreateOptions? options) + { + Throw.IfNull(function); + + List args = []; + if (function.JsonSchema.TryGetProperty("properties", out JsonElement properties)) + { + foreach (var param in properties.EnumerateObject()) + { + args.Add(new() + { + Name = param.Name, + Description = param.Value.TryGetProperty("description", out JsonElement description) ? description.GetString() : null, + Required = param.Value.TryGetProperty("required", out JsonElement required) && required.GetBoolean(), + }); + } + } + + Prompt prompt = new() + { + Name = options?.Name ?? function.Name, + Description = options?.Description ?? function.Description, + Arguments = args, + }; + + return new AIFunctionMcpServerPrompt(function, prompt); + } + + private static McpServerPromptCreateOptions? DeriveOptions(MethodInfo method, McpServerPromptCreateOptions? options) + { + McpServerPromptCreateOptions newOptions = options?.Clone() ?? new(); + + if (method.GetCustomAttribute() is { } attr) + { + newOptions.Name ??= attr.Name; + } + + return newOptions; + } + + /// Gets the wrapped by this prompt. + internal AIFunction AIFunction { get; } + + /// Initializes a new instance of the class. + private AIFunctionMcpServerPrompt(AIFunction function, Prompt prompt) + { + AIFunction = function; + ProtocolPrompt = prompt; + } + + /// + public override string ToString() => AIFunction.ToString(); + + /// + public override Prompt ProtocolPrompt { get; } + + /// + public override async Task GetAsync( + RequestContext request, CancellationToken cancellationToken = default) + { + Throw.IfNull(request); + + cancellationToken.ThrowIfCancellationRequested(); + + // TODO: Once we shift to the real AIFunctionFactory, the request should be passed via AIFunctionArguments.Context. + Dictionary arguments = request.Params?.Arguments is IDictionary existingArgs ? + new(existingArgs) : + []; + arguments[RequestContextKey] = request; + + object? result = await AIFunction.InvokeAsync(arguments, cancellationToken).ConfigureAwait(false); + + return result switch + { + GetPromptResult getPromptResult => getPromptResult, + + string text => new() + { + Description = ProtocolPrompt.Description, + Messages = [new() { Role = Role.User, Content = new() { Text = text, Type = "text" } }], + }, + + PromptMessage promptMessage => new() + { + Description = ProtocolPrompt.Description, + Messages = [promptMessage], + }, + + IEnumerable promptMessages => new() + { + Description = ProtocolPrompt.Description, + Messages = [.. promptMessages], + }, + + ChatMessage chatMessage => new() + { + Description = ProtocolPrompt.Description, + Messages = [.. chatMessage.ToPromptMessages()], + }, + + IEnumerable chatMessages => new() + { + Description = ProtocolPrompt.Description, + Messages = [.. chatMessages.SelectMany(chatMessage => chatMessage.ToPromptMessages())], + }, + + null => throw new InvalidOperationException($"Null result returned from prompt function."), + + _ => throw new InvalidOperationException($"Unknown result type '{result.GetType()}' returned from prompt function."), + }; + } +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/DelegatingMcpServerPrompt.cs b/src/ModelContextProtocol/Server/DelegatingMcpServerPrompt.cs new file mode 100644 index 00000000..7cf052eb --- /dev/null +++ b/src/ModelContextProtocol/Server/DelegatingMcpServerPrompt.cs @@ -0,0 +1,34 @@ +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Utils; + +namespace ModelContextProtocol.Server; + +/// Provides an that delegates all operations to an inner . +/// +/// This is recommended as a base type when building prompts that can be chained around an underlying . +/// The default implementation simply passes each call to the inner prompt instance. +/// +public abstract class DelegatingMcpServerPrompt : McpServerPrompt +{ + private readonly McpServerPrompt _innerPrompt; + + /// Initializes a new instance of the class around the specified . + /// The inner prompt wrapped by this delegating prompt. + protected DelegatingMcpServerPrompt(McpServerPrompt innerPrompt) + { + Throw.IfNull(innerPrompt); + _innerPrompt = innerPrompt; + } + + /// + public override Prompt ProtocolPrompt => _innerPrompt.ProtocolPrompt; + + /// + public override Task GetAsync( + RequestContext request, + CancellationToken cancellationToken = default) => + _innerPrompt.GetAsync(request, cancellationToken); + + /// + public override string ToString() => _innerPrompt.ToString(); +} diff --git a/src/ModelContextProtocol/Server/IMcpServerPrimitive.cs b/src/ModelContextProtocol/Server/IMcpServerPrimitive.cs new file mode 100644 index 00000000..d0676045 --- /dev/null +++ b/src/ModelContextProtocol/Server/IMcpServerPrimitive.cs @@ -0,0 +1,8 @@ +namespace ModelContextProtocol.Server; + +/// Represents an MCP server primitive, like a tool or a prompt. +public interface IMcpServerPrimitive +{ + /// Gets the name of the primitive. + string Name { get; } +} diff --git a/src/ModelContextProtocol/Server/McpServer.cs b/src/ModelContextProtocol/Server/McpServer.cs index 161736f7..d9e456c7 100644 --- a/src/ModelContextProtocol/Server/McpServer.cs +++ b/src/ModelContextProtocol/Server/McpServer.cs @@ -14,6 +14,7 @@ internal sealed class McpServer : McpJsonRpcEndpoint, IMcpServer { private readonly IServerTransport? _serverTransport; private readonly EventHandler? _toolsChangedDelegate; + private readonly EventHandler? _promptsChangedDelegate; private ITransport? _sessionTransport; private string _endpointName; @@ -77,6 +78,13 @@ private McpServer(McpServerOptions options, ILoggerFactory? loggerFactory, IServ Method = NotificationMethods.ToolListChangedNotification, }); }; + _promptsChangedDelegate = delegate + { + _ = SendMessageAsync(new JsonRpcNotification() + { + Method = NotificationMethods.PromptListChangedNotification, + }); + }; AddNotificationHandler(NotificationMethods.InitializedNotification, _ => { @@ -85,6 +93,11 @@ private McpServer(McpServerOptions options, ILoggerFactory? loggerFactory, IServ tools.Changed += _toolsChangedDelegate; } + if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts) + { + prompts.Changed += _promptsChangedDelegate; + } + return Task.CompletedTask; }); @@ -160,6 +173,11 @@ public override async ValueTask DisposeUnsynchronizedAsync() tools.Changed -= _toolsChangedDelegate; } + if (ServerOptions.Capabilities?.Prompts?.PromptCollection is { } prompts) + { + prompts.Changed -= _promptsChangedDelegate; + } + try { await base.DisposeUnsynchronizedAsync().ConfigureAwait(false); @@ -253,15 +271,97 @@ private void SetResourcesHandler(McpServerOptions options) private void SetPromptsHandler(McpServerOptions options) { - if (options.Capabilities?.Prompts is not { } promptsCapability) + PromptsCapability? promptsCapability = options.Capabilities?.Prompts; + var listPromptsHandler = promptsCapability?.ListPromptsHandler; + var getPromptHandler = promptsCapability?.GetPromptHandler; + var prompts = promptsCapability?.PromptCollection; + + if (listPromptsHandler is null != getPromptHandler is null) { - return; + throw new McpServerException("ListPrompts and GetPrompt handlers should be specified together."); } - if (promptsCapability.ListPromptsHandler is not { } listPromptsHandler || - promptsCapability.GetPromptHandler is not { } getPromptHandler) + // Handle tools provided via DI. + if (prompts is { IsEmpty: false }) + { + var originalListPromptsHandler = listPromptsHandler; + var originalGetPromptHandler = getPromptHandler; + + // Synthesize the handlers, making sure a ToolsCapability is specified. + listPromptsHandler = async (request, cancellationToken) => + { + ListPromptsResult result = new(); + foreach (McpServerPrompt prompt in prompts) + { + result.Prompts.Add(prompt.ProtocolPrompt); + } + + if (originalListPromptsHandler is not null) + { + string? nextCursor = null; + do + { + ListPromptsResult extraResults = await originalListPromptsHandler(request, cancellationToken).ConfigureAwait(false); + result.Prompts.AddRange(extraResults.Prompts); + + nextCursor = extraResults.NextCursor; + if (nextCursor is not null) + { + request = request with { Params = new() { Cursor = nextCursor } }; + } + } + while (nextCursor is not null); + } + + return result; + }; + + getPromptHandler = (request, cancellationToken) => + { + if (request.Params is null || + !prompts.TryGetPrimitive(request.Params.Name, out var prompt)) + { + if (originalGetPromptHandler is not null) + { + return originalGetPromptHandler(request, cancellationToken); + } + + throw new McpServerException($"Unknown prompt '{request.Params?.Name}'"); + } + + return prompt.GetAsync(request, cancellationToken); + }; + + ServerCapabilities = new() + { + Experimental = options.Capabilities?.Experimental, + Logging = options.Capabilities?.Logging, + Tools = options.Capabilities?.Tools, + Resources = options.Capabilities?.Resources, + Prompts = new() + { + ListPromptsHandler = listPromptsHandler, + GetPromptHandler = getPromptHandler, + PromptCollection = prompts, + ListChanged = true, + } + }; + } + else { - throw new McpServerException("Prompts capability was enabled, but ListPrompts and/or GetPrompt handlers were not specified."); + ServerCapabilities = options.Capabilities; + + if (promptsCapability is null) + { + // No prompts, and no prompts capability was declared, so nothing to do. + return; + } + + // Make sure the handlers are provided if the capability is enabled. + if (listPromptsHandler is null || getPromptHandler is null) + { + throw new McpServerException("ListPrompts and/or GetPrompt handlers were not specified but the Prompts capability was enabled."); + } } SetRequestHandler(RequestMethods.PromptsList, (request, ct) => listPromptsHandler(new(this, request), ct)); @@ -318,7 +418,7 @@ private void SetToolsHandler(McpServerOptions options) callToolHandler = (request, cancellationToken) => { if (request.Params is null || - !tools.TryGetTool(request.Params.Name, out var tool)) + !tools.TryGetPrimitive(request.Params.Name, out var tool)) { if (originalCallToolHandler is not null) { diff --git a/src/ModelContextProtocol/Server/McpServerPrimitiveCollection.cs b/src/ModelContextProtocol/Server/McpServerPrimitiveCollection.cs new file mode 100644 index 00000000..b6c1f170 --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerPrimitiveCollection.cs @@ -0,0 +1,163 @@ +using ModelContextProtocol.Utils; +using System.Collections; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; + +namespace ModelContextProtocol.Server; + +/// Provides a thread-safe collection of instances, indexed by their names. +/// Specifies the type of primitive stored in the collection. +public class McpServerPrimitiveCollection : ICollection, IReadOnlyCollection + where T : IMcpServerPrimitive +{ + /// Concurrent dictionary of primitives, indexed by their names. + private readonly ConcurrentDictionary _primitives = []; + + /// + /// Initializes a new instance of the class. + /// + public McpServerPrimitiveCollection() + { + } + + /// Occurs when the collection is changed. + /// + /// By default, this is raised when a primitive is added or removed. However, a derived implementation + /// may raise this event for other reasons, such as when a primitive is modified. + /// + public event EventHandler? Changed; + + /// Gets the number of primitives in the collection. + public int Count => _primitives.Count; + + /// Gets whether there are any primitives in the collection. + public bool IsEmpty => _primitives.IsEmpty; + + /// Raises if there are registered handlers. + protected void RaiseChanged() => Changed?.Invoke(this, EventArgs.Empty); + + /// Gets the with the specified from the collection. + /// The name of the primitive to retrieve. + /// The with the specified name. + /// is . + /// An primitive with the specified name does not exist in the collection. + public T this[string name] + { + get + { + Throw.IfNull(name); + return _primitives[name]; + } + } + + /// Clears all primitives from the collection. + public virtual void Clear() + { + _primitives.Clear(); + RaiseChanged(); + } + + /// Adds the specified to the collection. + /// The primitive to be added. + /// is . + /// A primitive with the same name as already exists in the collection. + public void Add(T primitive) + { + if (!TryAdd(primitive)) + { + throw new ArgumentException($"A primitive with the same name '{primitive.Name}' already exists in the collection.", nameof(primitive)); + } + } + + /// Adds the specified to the collection. + /// The primitive to be added. + /// if the primitive was added; otherwise, . + /// is . + public virtual bool TryAdd(T primitive) + { + Throw.IfNull(primitive); + + bool added = _primitives.TryAdd(primitive.Name, primitive); + if (added) + { + RaiseChanged(); + } + + return added; + } + + /// Removes the specified primitivefrom the collection. + /// The primitive to be removed from the collection. + /// + /// if the primitive was found in the collection and removed; otherwise, if it couldn't be found. + /// + /// is . + public virtual bool Remove(T primitive) + { + Throw.IfNull(primitive); + + bool removed = ((ICollection>)_primitives).Remove(new(primitive.Name, primitive)); + if (removed) + { + RaiseChanged(); + } + + return removed; + } + + /// Attempts to get the primitive with the specified name from the collection. + /// The name of the primitive to retrieve. + /// The primitive, if found; otherwise, . + /// + /// if the primitive was found in the collection and return; otherwise, if it couldn't be found. + /// + /// is . + public virtual bool TryGetPrimitive(string name, [NotNullWhen(true)] out T? primitive) + { + Throw.IfNull(name); + return _primitives.TryGetValue(name, out primitive); + } + + /// Checks if a specific primitive is present in the collection of primitives. + /// The primitive to search for in the collection. + /// if the primitive was found in the collection and return; otherwise, if it couldn't be found. + /// is . + public virtual bool Contains(T primitive) + { + Throw.IfNull(primitive); + return ((ICollection>)_primitives).Contains(new(primitive.Name, primitive)); + } + + /// Gets the names of all of the primitives in the collection. + public virtual ICollection PrimitiveNames => _primitives.Keys; + + /// Creates an array containing all of the primitives in the collection. + /// An array containing all of the primitives in the collection. + public virtual T[] ToArray() => _primitives.Values.ToArray(); + + /// + public virtual void CopyTo(T[] array, int arrayIndex) + { + Throw.IfNull(array); + + _primitives.Values.CopyTo(array, arrayIndex); + } + + /// + public virtual IEnumerator GetEnumerator() + { + foreach (var entry in _primitives) + { + yield return entry.Value; + } + } + + /// + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + + /// + bool ICollection.IsReadOnly => false; +} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/McpServerPrompt.cs b/src/ModelContextProtocol/Server/McpServerPrompt.cs new file mode 100644 index 00000000..6936bf29 --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerPrompt.cs @@ -0,0 +1,95 @@ +using Microsoft.Extensions.AI; +using ModelContextProtocol.Protocol.Types; +using System.Diagnostics.CodeAnalysis; +using System.Reflection; + +namespace ModelContextProtocol.Server; + +/// Represents an invocable prompt used by Model Context Protocol servers. +public abstract class McpServerPrompt : IMcpServerPrimitive +{ + /// Initializes a new instance of the class. + protected McpServerPrompt() + { + } + + /// Gets the protocol type for this instance. + public abstract Prompt ProtocolPrompt { get; } + + /// Invokes the . + /// The request information resulting in the invocation of this tool. + /// The to monitor for cancellation requests. The default is . + /// The call response from invoking the tool. + /// is . + public abstract Task GetAsync( + RequestContext request, + CancellationToken cancellationToken = default); + + /// + /// Creates an instance for a method, specified via a instance. + /// + /// The method to be represented via the created . + /// Optional options used in the creation of the to control its behavior. + /// The created for invoking . + /// is . + public static McpServerPrompt Create( + Delegate method, + McpServerPromptCreateOptions? options = null) => + AIFunctionMcpServerPrompt.Create(method, options); + + /// + /// Creates an instance for a method, specified via a instance. + /// + /// The method to be represented via the created . + /// The instance if is an instance method; otherwise, . + /// Optional options used in the creation of the to control its behavior. + /// The created for invoking . + /// is . + /// is an instance method but is . + public static McpServerPrompt Create( + MethodInfo method, + object? target = null, + McpServerPromptCreateOptions? options = null) => + AIFunctionMcpServerPrompt.Create(method, target, options); + + /// + /// Creates an instance for a method, specified via an for + /// and instance method, along with a representing the type of the target object to + /// instantiate each time the method is invoked. + /// + /// The instance method to be represented via the created . + /// + /// The to construct an instance of on which to invoke when + /// the resulting is invoked. If services are provided, + /// ActivatorUtilities.CreateInstance will be used to construct the instance using those services; otherwise, + /// is used, utilizing the type's public parameterless constructor. + /// If an instance can't be constructed, an exception is thrown during the function's invocation. + /// + /// Optional options used in the creation of the to control its behavior. + /// The created for invoking . + /// is . + public static McpServerPrompt Create( + MethodInfo method, + [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type targetType, + McpServerPromptCreateOptions? options = null) => + AIFunctionMcpServerPrompt.Create(method, targetType, options); + + /// Creates an that wraps the specified . + /// The function to wrap. + /// Optional options used in the creation of the to control its behavior. + /// is . + /// + /// Unlike the other overloads of Create, the created by + /// does not provide all of the special parameter handling for MCP-specific concepts, like . + /// + public static McpServerPrompt Create( + AIFunction function, + McpServerPromptCreateOptions? options = null) => + AIFunctionMcpServerPrompt.Create(function, options); + + /// + public override string ToString() => ProtocolPrompt.Name; + + /// + string IMcpServerPrimitive.Name => ProtocolPrompt.Name; +} diff --git a/src/ModelContextProtocol/Server/McpServerPromptAttribute.cs b/src/ModelContextProtocol/Server/McpServerPromptAttribute.cs new file mode 100644 index 00000000..0a0da56f --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerPromptAttribute.cs @@ -0,0 +1,19 @@ +namespace ModelContextProtocol.Server; + +/// +/// Used to indicate that a method should be considered an MCP prompt and describe it. +/// +[AttributeUsage(AttributeTargets.Method)] +public sealed class McpServerPromptAttribute : Attribute +{ + /// + /// Initializes a new instance of the class. + /// + public McpServerPromptAttribute() + { + } + + /// Gets the name of the prompt. + /// If , the method name will be used. + public string? Name { get; set; } +} diff --git a/src/ModelContextProtocol/Server/McpServerPromptCreateOptions.cs b/src/ModelContextProtocol/Server/McpServerPromptCreateOptions.cs new file mode 100644 index 00000000..31664707 --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerPromptCreateOptions.cs @@ -0,0 +1,45 @@ +using System.ComponentModel; + +namespace ModelContextProtocol.Server; + +/// Provides options for controlling the creation of an . +public sealed class McpServerPromptCreateOptions +{ + /// + /// Gets or sets optional services used in the construction of the . + /// + /// + /// These services will be used to determine which parameters should be satisifed from dependency injection; what services + /// are satisfied via this provider should match what's satisfied via the provider passed in at invocation time. + /// + public IServiceProvider? Services { get; set; } + + /// + /// Gets or sets the name to use for the . + /// + /// + /// If , but an is applied to the method, + /// the name from the attribute will be used. If that's not present, a name based on the method's name will be used. + /// + public string? Name { get; set; } + + /// + /// Gets or set the description to use for the . + /// + /// + /// If , but a is applied to the method, + /// the description from that attribute will be used. + /// + public string? Description { get; set; } + + /// + /// Creates a shallow clone of the current instance. + /// + internal McpServerPromptCreateOptions Clone() => + new McpServerPromptCreateOptions() + { + Services = Services, + Name = Name, + Description = Description, + }; +} diff --git a/src/ModelContextProtocol/Server/McpServerPromptTypeAttribute.cs b/src/ModelContextProtocol/Server/McpServerPromptTypeAttribute.cs new file mode 100644 index 00000000..be9e29a0 --- /dev/null +++ b/src/ModelContextProtocol/Server/McpServerPromptTypeAttribute.cs @@ -0,0 +1,14 @@ +using Microsoft.Extensions.DependencyInjection; + +namespace ModelContextProtocol.Server; + +/// +/// Used to attribute a type containing methods that should be exposed as MCP prompts. +/// +/// +/// This is primarily relevant to methods that scan types in an assembly looking for methods +/// to expose, such as . It is not +/// necessary to attribute types explicitly provided to a method like . +/// +[AttributeUsage(AttributeTargets.Class)] +public sealed class McpServerPromptTypeAttribute : Attribute; diff --git a/src/ModelContextProtocol/Server/McpServerTool.cs b/src/ModelContextProtocol/Server/McpServerTool.cs index d1003f57..abe7f99e 100644 --- a/src/ModelContextProtocol/Server/McpServerTool.cs +++ b/src/ModelContextProtocol/Server/McpServerTool.cs @@ -6,7 +6,7 @@ namespace ModelContextProtocol.Server; /// Represents an invocable tool used by Model Context Protocol clients and servers. -public abstract class McpServerTool +public abstract class McpServerTool : IMcpServerPrimitive { /// Initializes a new instance of the class. protected McpServerTool() @@ -89,4 +89,7 @@ public static McpServerTool Create( /// public override string ToString() => ProtocolTool.Name; + + /// + string IMcpServerPrimitive.Name => ProtocolTool.Name; } diff --git a/src/ModelContextProtocol/Server/McpServerToolAttribute.cs b/src/ModelContextProtocol/Server/McpServerToolAttribute.cs index 5320fafc..f4f1d394 100644 --- a/src/ModelContextProtocol/Server/McpServerToolAttribute.cs +++ b/src/ModelContextProtocol/Server/McpServerToolAttribute.cs @@ -19,7 +19,7 @@ public sealed class McpServerToolAttribute : Attribute internal bool? _readOnly; /// - /// Initializes a new instance of the class. + /// Initializes a new instance of the class. /// public McpServerToolAttribute() { diff --git a/src/ModelContextProtocol/Server/McpServerToolCollection.cs b/src/ModelContextProtocol/Server/McpServerToolCollection.cs deleted file mode 100644 index f5234aa9..00000000 --- a/src/ModelContextProtocol/Server/McpServerToolCollection.cs +++ /dev/null @@ -1,164 +0,0 @@ -using ModelContextProtocol.Utils; -using System.Collections; -using System.Collections.Concurrent; -using System.Diagnostics.CodeAnalysis; - -namespace ModelContextProtocol.Server; - -/// Provides a thread-safe collection of instances, indexed by their names. -public class McpServerToolCollection : ICollection, IReadOnlyCollection -{ - /// Concurrent dictionary of tools, indexed by their names. - private readonly ConcurrentDictionary _tools = []; - - /// - /// Initializes a new instance of the class. - /// - public McpServerToolCollection() - { - } - - /// Occurs when the collection is changed. - /// - /// By default, this is raised when a tool is added or removed. However, a derived implementation - /// may raise this event for other reasons, such as when a tool is modified. - /// - public event EventHandler? Changed; - - /// Gets the number of tools in the collection. - public int Count => _tools.Count; - - /// Gets whether there are any tools in the collection. - public bool IsEmpty => _tools.IsEmpty; - - /// Raises if there are registered handlers. - protected void RaiseChanged() => Changed?.Invoke(this, EventArgs.Empty); - - /// Gets the with the specified from the collection. - /// The name of the tool to retrieve. - /// The with the specified name. - /// is . - /// A tool with the specified name does not exist in the collection. - public McpServerTool this[string name] - { - get - { - Throw.IfNull(name); - return _tools[name]; - } - } - - /// Clears all tools from the collection. - public virtual void Clear() - { - _tools.Clear(); - RaiseChanged(); - } - - /// Adds the specified to the collection. - /// The tool to be added. - /// is . - /// A tool with the same name as already exists in the collection. - public void Add(McpServerTool tool) - { - if (!TryAdd(tool)) - { - throw new ArgumentException($"A tool with the same name '{tool.ProtocolTool.Name}' already exists in the collection.", nameof(tool)); - } - } - - /// Adds the specified to the collection. - /// The tool to be added. - /// if the tool was added; otherwise, . - /// is . - public virtual bool TryAdd(McpServerTool tool) - { - Throw.IfNull(tool); - - bool added = _tools.TryAdd(tool.ProtocolTool.Name, tool); - if (added) - { - RaiseChanged(); - } - - return added; - } - - /// Removes the specified toolfrom the collection. - /// The tool to be removed from the collection. - /// - /// if the tool was found in the collection and removed; otherwise, if it couldn't be found. - /// - /// is . - public virtual bool Remove(McpServerTool tool) - { - Throw.IfNull(tool); - - bool removed = ((ICollection>)_tools).Remove(new(tool.ProtocolTool.Name, tool)); - if (removed) - { - RaiseChanged(); - } - - return removed; - } - - /// Attempts to get the tool with the specified name from the collection. - /// The name of the tool to retrieve. - /// The tool, if found; otherwise, . - /// - /// if the tool was found in the collection and return; otherwise, if it couldn't be found. - /// - /// is . - public virtual bool TryGetTool(string name, [NotNullWhen(true)] out McpServerTool? tool) - { - Throw.IfNull(name); - return _tools.TryGetValue(name, out tool); - } - - /// Checks if a specific tool is present in the collection of tools. - /// The tool to search for in the collection. - /// if the tool was found in the collection and return; otherwise, if it couldn't be found. - /// is . - public virtual bool Contains(McpServerTool tool) - { - Throw.IfNull(tool); - return ((ICollection>)_tools).Contains(new(tool.ProtocolTool.Name, tool)); - } - - /// Gets the names of all of the tools in the collection. - public virtual ICollection ToolNames => _tools.Keys; - - /// Creates an array containing all of the tools in the collection. - /// An array containing all of the tools in the collection. - public virtual McpServerTool[] ToArray() => _tools.Values.ToArray(); - - /// - public virtual void CopyTo(McpServerTool[] array, int arrayIndex) - { - Throw.IfNull(array); - - foreach (var entry in _tools) - { - array[arrayIndex++] = entry.Value; - } - } - - /// - public virtual IEnumerator GetEnumerator() - { - foreach (var entry in _tools) - { - yield return entry.Value; - } - } - - /// - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - /// - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - /// - bool ICollection.IsReadOnly => false; -} \ No newline at end of file diff --git a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs index 67e7a99b..c79a09ed 100644 --- a/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs +++ b/src/ModelContextProtocol/Server/TemporaryAIFunctionFactory.cs @@ -201,9 +201,16 @@ public static ReflectionAIFunction Build(MethodInfo method, object? target, Temp throw new ArgumentException("Open generic methods are not supported", nameof(method)); } - if (!method.IsStatic && target is null) + if (method.IsStatic) + { + if (target is not null) + { + throw new ArgumentException("The specified method is a static method but the specified target is non-null.", nameof(target)); + } + } + else if (target is null) { - throw new ArgumentNullException("Target must not be null for an instance method.", nameof(target)); + throw new ArgumentNullException(nameof(target), "The specified method is an instance method but the specified target is null."); } var functionDescriptor = ReflectionAIFunctionDescriptor.GetOrCreate(method, options); diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index 51583583..710679e9 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -1,7 +1,5 @@ using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; using ModelContextProtocol.Client; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Server; using ModelContextProtocol.Tests.Transport; diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs index f208d615..f6b8f2b1 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientFactoryTests.cs @@ -1,6 +1,5 @@ using System.Threading.Channels; using ModelContextProtocol.Client; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs index b7de29d5..a77ae6b8 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTestFixture.cs @@ -1,5 +1,4 @@ using ModelContextProtocol.Client; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using Microsoft.Extensions.Logging; diff --git a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs index 76d91f65..b3cdfea3 100644 --- a/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/ClientIntegrationTests.cs @@ -4,7 +4,6 @@ using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Protocol.Messages; using System.Text.Json; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Tests.Utils; using System.Text.Encodings.Web; diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs index 9a0fe72b..22e53817 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsHandlerTests.cs @@ -1,4 +1,3 @@ -using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Server; using Microsoft.Extensions.DependencyInjection; diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs new file mode 100644 index 00000000..846ceebc --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -0,0 +1,259 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using ModelContextProtocol.Client; +using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using ModelContextProtocol.Tests.Transport; +using ModelContextProtocol.Tests.Utils; +using System.ComponentModel; +using System.IO.Pipelines; +using System.Threading.Channels; + +namespace ModelContextProtocol.Tests.Configuration; + +public class McpServerBuilderExtensionsPromptsTests : LoggedTest, IAsyncDisposable +{ + private readonly Pipe _clientToServerPipe = new(); + private readonly Pipe _serverToClientPipe = new(); + private readonly ServiceProvider _serviceProvider; + private readonly IMcpServerBuilder _builder; + private readonly CancellationTokenSource _cts; + private readonly Task _serverTask; + + public McpServerBuilderExtensionsPromptsTests(ITestOutputHelper testOutputHelper) + : base(testOutputHelper) + { + ServiceCollection sc = new(); + sc.AddSingleton(LoggerFactory); + _builder = sc.AddMcpServer().WithStdioServerTransport().WithPrompts(); + // Call WithStdioServerTransport to get the IMcpServer registration, then overwrite default transport with a pipe transport. + sc.AddSingleton(new StdioServerTransport("TestServer", _clientToServerPipe.Reader.AsStream(), _serverToClientPipe.Writer.AsStream(), LoggerFactory)); + sc.AddSingleton(new ObjectWithId()); + _serviceProvider = sc.BuildServiceProvider(); + + var server = _serviceProvider.GetRequiredService(); + _cts = CancellationTokenSource.CreateLinkedTokenSource(TestContext.Current.CancellationToken); + _serverTask = server.RunAsync(cancellationToken: _cts.Token); + } + + public async ValueTask DisposeAsync() + { + await _cts.CancelAsync(); + + _clientToServerPipe.Writer.Complete(); + _serverToClientPipe.Writer.Complete(); + + await _serverTask; + + await _serviceProvider.DisposeAsync(); + _cts.Dispose(); + Dispose(); + } + + private async Task CreateMcpClientForServer() + { + var serverStdinWriter = new StreamWriter(_clientToServerPipe.Writer.AsStream()); + var serverStdoutReader = new StreamReader(_serverToClientPipe.Reader.AsStream()); + + var serverConfig = new McpServerConfig() + { + Id = "TestServer", + Name = "TestServer", + TransportType = "ignored", + }; + + return await McpClientFactory.CreateAsync( + serverConfig, + createTransportFunc: (_, _) => new StreamClientTransport(serverStdinWriter, serverStdoutReader, LoggerFactory), + loggerFactory: LoggerFactory, + cancellationToken: TestContext.Current.CancellationToken); + } + + [Fact] + public void Adds_Prompts_To_Server() + { + var serverOptions = _serviceProvider.GetRequiredService>().Value; + var prompts = serverOptions?.Capabilities?.Prompts?.PromptCollection; + Assert.NotNull(prompts); + Assert.NotEmpty(prompts); + } + + [Fact] + public async Task Can_List_And_Call_Registered_Prompts() + { + IMcpClient client = await CreateMcpClientForServer(); + + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); + Assert.Equal(3, prompts.Count); + + var prompt = prompts.First(t => t.Name == nameof(SimplePrompts.ReturnsChatMessages)); + Assert.Equal("Returns chat messages", prompt.Description); + + var result = await prompt.GetAsync(new Dictionary() { ["message"] = "hello" }, TestContext.Current.CancellationToken); + var chatMessages = result.ToChatMessages(); + + Assert.NotNull(chatMessages); + Assert.NotEmpty(chatMessages); + Assert.Equal(2, chatMessages.Count); + Assert.Equal("The prompt is: hello", chatMessages[0].Text); + Assert.Equal("Summarize.", chatMessages[1].Text); + } + + [Fact] + public async Task Can_Be_Notified_Of_Prompt_Changes() + { + IMcpClient client = await CreateMcpClientForServer(); + + var prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); + Assert.Equal(3, prompts.Count); + + Channel listChanged = Channel.CreateUnbounded(); + client.AddNotificationHandler("notifications/prompts/list_changed", notification => + { + listChanged.Writer.TryWrite(notification); + return Task.CompletedTask; + }); + + var notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.False(notificationRead.IsCompleted); + + var serverOptions = _serviceProvider.GetRequiredService>().Value; + var serverPrompts = serverOptions.Capabilities?.Prompts?.PromptCollection; + Assert.NotNull(serverPrompts); + + var newPrompt = McpServerPrompt.Create([McpServerPrompt(Name = "NewPrompt")] () => "42"); + serverPrompts.Add(newPrompt); + await notificationRead; + + prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); + Assert.Equal(4, prompts.Count); + Assert.Contains(prompts, t => t.Name == "NewPrompt"); + + notificationRead = listChanged.Reader.ReadAsync(TestContext.Current.CancellationToken); + Assert.False(notificationRead.IsCompleted); + serverPrompts.Remove(newPrompt); + await notificationRead; + + prompts = await client.ListPromptsAsync(TestContext.Current.CancellationToken); + Assert.Equal(3, prompts.Count); + Assert.DoesNotContain(prompts, t => t.Name == "NewPrompt"); + } + + [Fact] + public async Task Throws_When_Prompt_Fails() + { + IMcpClient client = await CreateMcpClientForServer(); + + await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + nameof(SimplePrompts.ThrowsException), + cancellationToken: TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task Throws_Exception_On_Unknown_Prompt() + { + IMcpClient client = await CreateMcpClientForServer(); + + var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + "NotRegisteredPrompt", + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Contains("'NotRegisteredPrompt'", e.Message); + } + + [Fact] + public async Task Throws_Exception_Missing_Parameter() + { + IMcpClient client = await CreateMcpClientForServer(); + + var e = await Assert.ThrowsAsync(async () => await client.GetPromptAsync( + nameof(SimplePrompts.ReturnsChatMessages), + cancellationToken: TestContext.Current.CancellationToken)); + + Assert.Contains("Missing required parameter", e.Message); + } + + [Fact] + public void WithPrompts_InvalidArgs_Throws() + { + Assert.Throws("promptTypes", () => _builder.WithPrompts((IEnumerable)null!)); + + IMcpServerBuilder nullBuilder = null!; + Assert.Throws("builder", () => nullBuilder.WithPrompts()); + Assert.Throws("builder", () => nullBuilder.WithPrompts(Array.Empty())); + Assert.Throws("builder", () => nullBuilder.WithPromptsFromAssembly()); + } + + [Fact] + public void Empty_Enumerables_Is_Allowed() + { + _builder.WithPrompts(promptTypes: []); // no exception + _builder.WithPrompts(); // no exception even though no prompts exposed + _builder.WithPromptsFromAssembly(typeof(AIFunction).Assembly); // no exception even though no prompts exposed + } + + [Fact] + public void Register_Prompts_From_Current_Assembly() + { + ServiceCollection sc = new(); + sc.AddMcpServer().WithPromptsFromAssembly(); + IServiceProvider services = sc.BuildServiceProvider(); + + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsChatMessages)); + } + + [Fact] + public void Register_Prompts_From_Multiple_Sources() + { + ServiceCollection sc = new(); + sc.AddMcpServer() + .WithPrompts() + .WithPrompts(); + IServiceProvider services = sc.BuildServiceProvider(); + + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsChatMessages)); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ThrowsException)); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsString)); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(MorePrompts.AnotherPrompt)); + } + + [McpServerToolType] + public sealed class SimplePrompts(ObjectWithId? id = null) + { + [McpServerPrompt, Description("Returns chat messages")] + public static ChatMessage[] ReturnsChatMessages([Description("The first parameter")] string message) => + [ + new(ChatRole.User, $"The prompt is: {message}"), + new(ChatRole.User, "Summarize."), + ]; + + + [McpServerPrompt, Description("Returns chat messages")] + public static ChatMessage[] ThrowsException([Description("The first parameter")] string message) => + throw new FormatException("uh oh"); + + [McpServerPrompt, Description("Returns chat messages")] + public string ReturnsString([Description("The first parameter")] string message) => + $"The prompt is: {message}. The id is {id}."; + } + + [McpServerToolType] + public sealed class MorePrompts + { + [McpServerPrompt] + public static PromptMessage AnotherPrompt() => + new PromptMessage + { + Role = Role.User, + Content = new() { Text = "hello", Type = "text" }, + }; + } + + public class ObjectWithId + { + public string Id { get; set; } = Guid.NewGuid().ToString("N"); + } +} diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index baeb2c7c..2a1ddafa 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -3,7 +3,6 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using ModelContextProtocol.Client; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; @@ -356,16 +355,16 @@ public async Task Throws_Exception_On_Unknown_Tool() Assert.Contains("'NotRegisteredTool'", e.Message); } - [Fact(Skip = "https://github.com/dotnet/extensions/issues/6124")] - public async Task Throws_Exception_Missing_Parameter() + [Fact] + public async Task Returns_IsError_Missing_Parameter() { IMcpClient client = await CreateMcpClientForServer(); - var e = await Assert.ThrowsAsync(async () => await client.CallToolAsync( + var result = await client.CallToolAsync( "Echo", - cancellationToken: TestContext.Current.CancellationToken)); + cancellationToken: TestContext.Current.CancellationToken); - Assert.Equal("Missing required argument 'message'.", e.Message); + Assert.True(result.IsError); } [Fact] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs index 3bd09c0e..1f8acb38 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsTransportsTests.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Configuration; -using ModelContextProtocol.Protocol.Transport; +using ModelContextProtocol.Protocol.Transport; using Microsoft.Extensions.DependencyInjection; using Moq; diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs new file mode 100644 index 00000000..cca61721 --- /dev/null +++ b/tests/ModelContextProtocol.Tests/Server/McpServerPromptTests.cs @@ -0,0 +1,399 @@ +using Microsoft.Extensions.AI; +using Microsoft.Extensions.DependencyInjection; +using ModelContextProtocol.Protocol.Types; +using ModelContextProtocol.Server; +using Moq; +using System.Reflection; + +namespace ModelContextProtocol.Tests.Server; + +public class McpServerPromptTests +{ + [Fact] + public void Create_InvalidArgs_Throws() + { + Assert.Throws("function", () => McpServerPrompt.Create((AIFunction)null!)); + Assert.Throws("method", () => McpServerPrompt.Create((MethodInfo)null!)); + Assert.Throws("method", () => McpServerPrompt.Create((MethodInfo)null!, typeof(object))); + Assert.Throws("targetType", () => McpServerPrompt.Create(typeof(McpServerPromptTests).GetMethod(nameof(Create_InvalidArgs_Throws))!, (Type)null!)); + Assert.Throws("method", () => McpServerPrompt.Create((Delegate)null!)); + } + + [Fact] + public async Task SupportsIMcpServer() + { + Mock mockServer = new(); + + McpServerPrompt prompt = McpServerPrompt.Create((IMcpServer server) => + { + Assert.Same(mockServer.Object, server); + return new ChatMessage(ChatRole.User, "Hello"); + }); + + Assert.DoesNotContain("server", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); + + var result = await prompt.GetAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.NotNull(result); + Assert.NotNull(result.Messages); + Assert.Single(result.Messages); + Assert.Equal("Hello", result.Messages[0].Content.Text); + } + + [Fact] + public async Task SupportsServiceFromDI() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + McpServerPrompt prompt = McpServerPrompt.Create((MyService actualMyService, int? something = null) => + { + Assert.Same(expectedMyService, actualMyService); + return new PromptMessage() { Role = Role.Assistant, Content = new() { Text = "Hello", Type = "text" } }; + }, new() { Services = services }); + + Assert.Contains("something", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); + Assert.DoesNotContain("actualMyService", prompt.ProtocolPrompt.Arguments?.Select(a => a.Name) ?? []); + + Mock mockServer = new(); + + await Assert.ThrowsAsync(async () => await prompt.GetAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken)); + + mockServer.SetupGet(x => x.Services).Returns(services); + + var result = await prompt.GetAsync( + new RequestContext(mockServer.Object, null), + TestContext.Current.CancellationToken); + Assert.Equal("Hello", result.Messages[0].Content.Text); + } + + [Fact] + public async Task SupportsOptionalServiceFromDI() + { + MyService expectedMyService = new(); + + ServiceCollection sc = new(); + sc.AddSingleton(expectedMyService); + IServiceProvider services = sc.BuildServiceProvider(); + + McpServerPrompt prompt = McpServerPrompt.Create((MyService? actualMyService = null) => + { + Assert.Null(actualMyService); + return new PromptMessage() { Role = Role.Assistant, Content = new() { Text = "Hello", Type = "text" } }; + }, new() { Services = services }); + + var result = await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("Hello", result.Messages[0].Content.Text); + } + + [Fact] + public async Task SupportsDisposingInstantiatedDisposableTargets() + { + McpServerPrompt prompt1 = McpServerPrompt.Create( + typeof(DisposablePromptType).GetMethod(nameof(DisposablePromptType.InstanceMethod))!, + typeof(DisposablePromptType)); + + var result = await prompt1.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("disposals:1", result.Messages[0].Content.Text); + } + + [Fact] + public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableTargets() + { + McpServerPrompt prompt1 = McpServerPrompt.Create( + typeof(AsyncDisposablePromptType).GetMethod(nameof(AsyncDisposablePromptType.InstanceMethod))!, + typeof(AsyncDisposablePromptType)); + + var result = await prompt1.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("asyncDisposals:1", result.Messages[0].Content.Text); + } + + [Fact] + public async Task SupportsAsyncDisposingInstantiatedAsyncDisposableAndDisposableTargets() + { + McpServerPrompt prompt1 = McpServerPrompt.Create( + typeof(AsyncDisposableAndDisposablePromptType).GetMethod(nameof(AsyncDisposableAndDisposablePromptType.InstanceMethod))!, + typeof(AsyncDisposableAndDisposablePromptType)); + + var result = await prompt1.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + Assert.Equal("disposals:0, asyncDisposals:1", result.Messages[0].Content.Text); + } + + [Fact] + public async Task CanReturnGetPromptResult() + { + GetPromptResult expected = new(); + + McpServerPrompt prompt = McpServerPrompt.Create(() => + { + return expected; + }); + + var actual = await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + + Assert.Same(expected, actual); + } + + [Fact] + public async Task CanReturnText() + { + string expected = "hello"; + + McpServerPrompt prompt = McpServerPrompt.Create(() => + { + return expected; + }); + + var actual = await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + + Assert.NotNull(actual); + Assert.NotNull(actual.Messages); + Assert.Single(actual.Messages); + Assert.Equal(Role.User, actual.Messages[0].Role); + Assert.Equal("text", actual.Messages[0].Content.Type); + Assert.Equal(expected, actual.Messages[0].Content.Text); + } + + [Fact] + public async Task CanReturnPromptMessage() + { + PromptMessage expected = new() + { + Role = Role.User, + Content = new() { Text = "hello", Type = "text" } + }; + + McpServerPrompt prompt = McpServerPrompt.Create(() => + { + return expected; + }); + + var actual = await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + + Assert.NotNull(actual); + Assert.NotNull(actual.Messages); + Assert.Single(actual.Messages); + Assert.Same(expected, actual.Messages[0]); + } + + [Fact] + public async Task CanReturnPromptMessages() + { + PromptMessage[] expected = [ + new() + { + Role = Role.User, + Content = new() { Text = "hello", Type = "text" } + }, + new() + { + Role = Role.Assistant, + Content = new() { Text = "hello again", Type = "text" } + } + ]; + + McpServerPrompt prompt = McpServerPrompt.Create(() => + { + return expected; + }); + + var actual = await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + + Assert.NotNull(actual); + Assert.NotNull(actual.Messages); + Assert.Equal(2, actual.Messages.Count); + Assert.Equal(Role.User, actual.Messages[0].Role); + Assert.Equal("text", actual.Messages[0].Content.Type); + Assert.Equal("hello", actual.Messages[0].Content.Text); + Assert.Equal(Role.Assistant, actual.Messages[1].Role); + Assert.Equal("text", actual.Messages[1].Content.Type); + Assert.Equal("hello again", actual.Messages[1].Content.Text); + } + + [Fact] + public async Task CanReturnChatMessage() + { + PromptMessage expected = new() + { + Role = Role.User, + Content = new() { Text = "hello", Type = "text" } + }; + + McpServerPrompt prompt = McpServerPrompt.Create(() => + { + return expected.ToChatMessage(); + }); + + var actual = await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + + Assert.NotNull(actual); + Assert.NotNull(actual.Messages); + Assert.Single(actual.Messages); + Assert.Equal(Role.User, actual.Messages[0].Role); + Assert.Equal("text", actual.Messages[0].Content.Type); + Assert.Equal("hello", actual.Messages[0].Content.Text); + } + + [Fact] + public async Task CanReturnChatMessages() + { + PromptMessage[] expected = [ + new() + { + Role = Role.User, + Content = new() { Text = "hello", Type = "text" } + }, + new() + { + Role = Role.Assistant, + Content = new() { Text = "hello again", Type = "text" } + } + ]; + + McpServerPrompt prompt = McpServerPrompt.Create(() => + { + return expected.Select(p => p.ToChatMessage()); + }); + + var actual = await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken); + + Assert.NotNull(actual); + Assert.NotNull(actual.Messages); + Assert.Equal(2, actual.Messages.Count); + Assert.Equal(Role.User, actual.Messages[0].Role); + Assert.Equal("text", actual.Messages[0].Content.Type); + Assert.Equal("hello", actual.Messages[0].Content.Text); + Assert.Equal(Role.Assistant, actual.Messages[1].Role); + Assert.Equal("text", actual.Messages[1].Content.Type); + Assert.Equal("hello again", actual.Messages[1].Content.Text); + } + + [Fact] + public async Task ThrowsForNullReturn() + { + McpServerPrompt prompt = McpServerPrompt.Create(() => + { + return (string)null!; + }); + + await Assert.ThrowsAsync(async () => await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken)); + } + + [Fact] + public async Task ThrowsForUnexpectedTypeReturn() + { + McpServerPrompt prompt = McpServerPrompt.Create(() => + { + return new object(); + }); + + await Assert.ThrowsAsync(async () => await prompt.GetAsync( + new RequestContext(null!, null), + TestContext.Current.CancellationToken)); + } + + private sealed class MyService; + + private class DisposablePromptType : IDisposable + { + public int Disposals { get; private set; } + private ChatMessage _message = new ChatMessage(ChatRole.User, ""); + + public void Dispose() + { + Disposals++; + ((TextContent)_message.Contents[0]).Text = $"disposals:{Disposals}"; + } + + public ChatMessage InstanceMethod() + { + if (Disposals != 0) + { + throw new InvalidOperationException("Dispose was called"); + } + + return _message; + } + } + + private class AsyncDisposablePromptType : IAsyncDisposable + { + public int AsyncDisposals { get; private set; } + private ChatMessage _message = new ChatMessage(ChatRole.User, ""); + + public ValueTask DisposeAsync() + { + AsyncDisposals++; + ((TextContent)_message.Contents[0]).Text = $"asyncDisposals:{AsyncDisposals}"; + return default; + } + + public ChatMessage InstanceMethod() + { + if (AsyncDisposals != 0) + { + throw new InvalidOperationException("DisposeAsync was called"); + } + + return _message; + } + } + + private class AsyncDisposableAndDisposablePromptType : IAsyncDisposable, IDisposable + { + public int Disposals { get; private set; } + public int AsyncDisposals { get; private set; } + private ChatMessage _message = new ChatMessage(ChatRole.User, ""); + + public void Dispose() + { + Disposals++; + ((TextContent)_message.Contents[0]).Text = $"disposals:{Disposals}, asyncDisposals:{AsyncDisposals}"; + } + + public ValueTask DisposeAsync() + { + AsyncDisposals++; + ((TextContent)_message.Contents[0]).Text = $"disposals:{Disposals}, asyncDisposals:{AsyncDisposals}"; + return default; + } + + public ChatMessage InstanceMethod() + { + if (Disposals + AsyncDisposals != 0) + { + throw new InvalidOperationException("Dispose and/or DisposeAsync was called"); + } + + return _message; + } + } +} diff --git a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs index 5c04caad..2d747d9d 100644 --- a/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs +++ b/tests/ModelContextProtocol.Tests/Server/McpServerToolTests.cs @@ -19,6 +19,11 @@ public void Create_InvalidArgs_Throws() Assert.Throws("method", () => McpServerTool.Create((MethodInfo)null!, typeof(object))); Assert.Throws("targetType", () => McpServerTool.Create(typeof(McpServerToolTests).GetMethod(nameof(Create_InvalidArgs_Throws))!, (Type)null!)); Assert.Throws("method", () => McpServerTool.Create((Delegate)null!)); + + Assert.NotNull(McpServerTool.Create(typeof(DisposableToolType).GetMethod(nameof(DisposableToolType.InstanceMethod))!, new DisposableToolType())); + Assert.NotNull(McpServerTool.Create(typeof(DisposableToolType).GetMethod(nameof(DisposableToolType.StaticMethod))!)); + Assert.Throws("target", () => McpServerTool.Create(typeof(DisposableToolType).GetMethod(nameof(DisposableToolType.InstanceMethod))!, target: null!)); + Assert.Throws("target", () => McpServerTool.Create(typeof(DisposableToolType).GetMethod(nameof(DisposableToolType.StaticMethod))!, new DisposableToolType())); } [Fact] @@ -339,6 +344,11 @@ public object InstanceMethod() return this; } + + public static object StaticMethod() + { + return "42"; + } } private class AsyncDisposableToolType : IAsyncDisposable diff --git a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs index a21e3d61..1874953d 100644 --- a/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs +++ b/tests/ModelContextProtocol.Tests/SseIntegrationTests.cs @@ -1,10 +1,8 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Protocol.Types; using ModelContextProtocol.Tests.Utils; -using System.Reflection; using System.Text.Json; namespace ModelContextProtocol.Tests; diff --git a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs index 16df076a..c3b0fbbf 100644 --- a/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs +++ b/tests/ModelContextProtocol.Tests/SseServerIntegrationTestFixture.cs @@ -1,6 +1,5 @@ using Microsoft.Extensions.Logging; using ModelContextProtocol.Client; -using ModelContextProtocol.Configuration; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Test.Utils; using ModelContextProtocol.Tests.Utils; diff --git a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs index 8e64dd3b..23061cd9 100644 --- a/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/SseClientTransportTests.cs @@ -1,5 +1,4 @@ -using ModelContextProtocol.Configuration; -using ModelContextProtocol.Protocol.Messages; +using ModelContextProtocol.Protocol.Messages; using ModelContextProtocol.Protocol.Transport; using ModelContextProtocol.Tests.Utils; using System.IO.Pipelines;