diff --git a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs index 05872aeb..b389979e 100644 --- a/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs +++ b/src/ModelContextProtocol/Configuration/McpServerBuilderExtensions.cs @@ -57,7 +57,29 @@ public static partial class McpServerBuilderExtensions /// Adds instances to the service collection backing . /// The builder instance. - /// Types with marked methods to add as tools to the server. + /// The instances to add to the server. + /// The builder provided in . + /// is . + /// is . + public static IMcpServerBuilder WithTools(this IMcpServerBuilder builder, IEnumerable tools) + { + Throw.IfNull(builder); + Throw.IfNull(tools); + + foreach (var tool in tools) + { + if (tool is not null) + { + builder.Services.AddSingleton(tool); + } + } + + return builder; + } + + /// Adds instances to the service collection backing . + /// The builder instance. + /// Types with -attributed methods to add as tools to the server. /// The serializer options governing tool parameter marshalling. /// The builder provided in . /// is . @@ -173,6 +195,28 @@ where t.GetCustomAttribute() is not null return builder; } + /// Adds instances to the service collection backing . + /// The builder instance. + /// The instances to add to the server. + /// The builder provided in . + /// is . + /// is . + public static IMcpServerBuilder WithPrompts(this IMcpServerBuilder builder, IEnumerable prompts) + { + Throw.IfNull(builder); + Throw.IfNull(prompts); + + foreach (var prompt in prompts) + { + if (prompt is not null) + { + builder.Services.AddSingleton(prompt); + } + } + + return builder; + } + /// Adds instances to the service collection backing . /// The builder instance. /// Types with marked methods to add as prompts to the server. diff --git a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs index a5bf96fe..054d5227 100644 --- a/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs +++ b/tests/ModelContextProtocol.Tests/Client/McpClientExtensionsTests.cs @@ -25,10 +25,10 @@ protected override void ConfigureServices(ServiceCollection services, IMcpServer for (int f = 0; f < 10; f++) { string name = $"Method{f}"; - services.AddSingleton(McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name })); + mcpServerBuilder.WithTools([McpServerTool.Create((int i) => $"{name} Result {i}", new() { Name = name })]); } - services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })); - services.AddSingleton(McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })); + mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaAttr" })]); + mcpServerBuilder.WithTools([McpServerTool.Create([McpServerTool(Destructive = false, OpenWorld = true)] (string i) => $"{i} Result", new() { Name = "ValuesSetViaOptions", Destructive = true, OpenWorld = false, ReadOnly = true })]); } [Theory] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs index a66cbeca..1bede93c 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsPromptsTests.cs @@ -202,6 +202,7 @@ public void WithPrompts_InvalidArgs_Throws() { IMcpServerBuilder builder = new ServiceCollection().AddMcpServer(); + Assert.Throws("prompts", () => builder.WithPrompts((IEnumerable)null!)); Assert.Throws("promptTypes", () => builder.WithPrompts((IEnumerable)null!)); IMcpServerBuilder nullBuilder = null!; @@ -215,6 +216,7 @@ public void Empty_Enumerables_Is_Allowed() { IMcpServerBuilder builder = new ServiceCollection().AddMcpServer(); + builder.WithPrompts(prompts: []); // no exception builder.WithPrompts(promptTypes: []); // no exception builder.WithPrompts(); // no exception even though no prompts exposed builder.WithPromptsFromAssembly(typeof(AIFunction).Assembly); // no exception even though no prompts exposed @@ -236,13 +238,15 @@ public void Register_Prompts_From_Multiple_Sources() ServiceCollection sc = new(); sc.AddMcpServer() .WithPrompts() - .WithPrompts(JsonContext4.Default.Options); + .WithPrompts(JsonContext4.Default.Options) + .WithPrompts([McpServerPrompt.Create(() => "42", new() { Name = "Returns42" })]); IServiceProvider services = sc.BuildServiceProvider(); Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsChatMessages)); Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ThrowsException)); Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(SimplePrompts.ReturnsString)); Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == nameof(MorePrompts.AnotherPrompt)); + Assert.Contains(services.GetServices(), t => t.ProtocolPrompt.Name == "Returns42"); } [McpServerPromptType] diff --git a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs index 40ca2446..16a69cdf 100644 --- a/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs +++ b/tests/ModelContextProtocol.Tests/Configuration/McpServerBuilderExtensionsToolsTests.cs @@ -408,6 +408,7 @@ public void WithTools_InvalidArgs_Throws() { IMcpServerBuilder builder = new ServiceCollection().AddMcpServer(); + Assert.Throws("tools", () => builder.WithTools((IEnumerable)null!)); Assert.Throws("toolTypes", () => builder.WithTools((IEnumerable)null!)); IMcpServerBuilder nullBuilder = null!; @@ -421,6 +422,7 @@ public void Empty_Enumerables_Is_Allowed() { IMcpServerBuilder builder = new ServiceCollection().AddMcpServer(); + builder.WithTools(tools: []); // no exception builder.WithTools(toolTypes: []); // no exception builder.WithTools(); // no exception even though no tools exposed builder.WithToolsFromAssembly(typeof(AIFunction).Assembly); // no exception even though no tools exposed @@ -539,7 +541,8 @@ public void Register_Tools_From_Multiple_Sources() sc.AddMcpServer() .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options) .WithTools(serializerOptions: BuilderToolsJsonContext.Default.Options) - .WithTools([typeof(ToolTypeWithNoAttribute)], BuilderToolsJsonContext.Default.Options); + .WithTools([typeof(ToolTypeWithNoAttribute)], BuilderToolsJsonContext.Default.Options) + .WithTools([McpServerTool.Create(() => "42", new() { Name = "Returns42" })]); IServiceProvider services = sc.BuildServiceProvider(); Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "double_echo"); @@ -547,6 +550,7 @@ public void Register_Tools_From_Multiple_Sources() Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodB"); Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodC"); Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "MethodD"); + Assert.Contains(services.GetServices(), t => t.ProtocolTool.Name == "Returns42"); } [Fact]