Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(mcp): Add builder for CreateMessageRequest #60

Merged
merged 3 commits into from
Mar 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
import io.modelcontextprotocol.spec.McpSchema.Role;
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
Expand All @@ -45,6 +46,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;

public class WebFluxSseIntegrationTests {
Expand Down Expand Up @@ -142,13 +144,16 @@ void testCreateMessageSuccess(String clientType) throws InterruptedException {
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {

var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
new McpSchema.TextContent("Test message")));
var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);

var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
Map.of());
var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
.messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
new McpSchema.TextContent("Test message"))))
.modelPreferences(ModelPreferences.builder()
.hints(List.of())
.costPriority(1.0)
.speedPriority(1.0)
.intelligencePriority(1.0)
.build())
.build();

StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
assertThat(result).isNotNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
import io.modelcontextprotocol.spec.McpSchema.Role;
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
Expand All @@ -45,6 +46,7 @@

import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.mock;

public class WebMvcSseIntegrationTests {
Expand Down Expand Up @@ -199,13 +201,16 @@ void testCreateMessageSuccess() throws InterruptedException {
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {

var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
new McpSchema.TextContent("Test message")));
var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);

var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
Map.of());
var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
.messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
new McpSchema.TextContent("Test message"))))
.modelPreferences(ModelPreferences.builder()
.hints(List.of())
.costPriority(1.0)
.speedPriority(1.0)
.intelligencePriority(1.0)
.build())
.build();

StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
assertThat(result).isNotNull();
Expand Down
117 changes: 112 additions & 5 deletions mcp/src/main/java/io/modelcontextprotocol/spec/McpSchema.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package io.modelcontextprotocol.spec;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -763,15 +764,61 @@ public record CallToolResult( // @formatter:off
@JsonInclude(JsonInclude.Include.NON_ABSENT)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ModelPreferences(// @formatter:off
@JsonProperty("hints") List<ModelHint> hints,
@JsonProperty("costPriority") Double costPriority,
@JsonProperty("speedPriority") Double speedPriority,
@JsonProperty("intelligencePriority") Double intelligencePriority) {
} // @formatter:on
@JsonProperty("hints") List<ModelHint> hints,
@JsonProperty("costPriority") Double costPriority,
@JsonProperty("speedPriority") Double speedPriority,
@JsonProperty("intelligencePriority") Double intelligencePriority) {

public static Builder builder() {
return new Builder();
}

public static class Builder {
private List<ModelHint> hints;
private Double costPriority;
private Double speedPriority;
private Double intelligencePriority;

public Builder hints(List<ModelHint> hints) {
this.hints = hints;
return this;
}

public Builder addHint(String name) {
if (this.hints == null) {
this.hints = new ArrayList<>();
}
this.hints.add(new ModelHint(name));
return this;
}

public Builder costPriority(Double costPriority) {
this.costPriority = costPriority;
return this;
}

public Builder speedPriority(Double speedPriority) {
this.speedPriority = speedPriority;
return this;
}

public Builder intelligencePriority(Double intelligencePriority) {
this.intelligencePriority = intelligencePriority;
return this;
}

public ModelPreferences build() {
return new ModelPreferences(hints, costPriority, speedPriority, intelligencePriority);
}
}
} // @formatter:on

@JsonInclude(JsonInclude.Include.NON_ABSENT)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ModelHint(@JsonProperty("name") String name) {
public static ModelHint of(String name) {
return new ModelHint(name);
}
}

@JsonInclude(JsonInclude.Include.NON_ABSENT)
Expand Down Expand Up @@ -799,6 +846,66 @@ public enum ContextInclusionStrategy {
@JsonProperty("thisServer") THIS_SERVER,
@JsonProperty("allServers") ALL_SERVERS
}

public static Builder builder() {
return new Builder();
}

public static class Builder {
private List<SamplingMessage> messages;
private ModelPreferences modelPreferences;
private String systemPrompt;
private ContextInclusionStrategy includeContext;
private Double temperature;
private int maxTokens;
private List<String> stopSequences;
private Map<String, Object> metadata;

public Builder messages(List<SamplingMessage> messages) {
this.messages = messages;
return this;
}

public Builder modelPreferences(ModelPreferences modelPreferences) {
this.modelPreferences = modelPreferences;
return this;
}

public Builder systemPrompt(String systemPrompt) {
this.systemPrompt = systemPrompt;
return this;
}

public Builder includeContext(ContextInclusionStrategy includeContext) {
this.includeContext = includeContext;
return this;
}

public Builder temperature(Double temperature) {
this.temperature = temperature;
return this;
}

public Builder maxTokens(int maxTokens) {
this.maxTokens = maxTokens;
return this;
}

public Builder stopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
return this;
}

public Builder metadata(Map<String, Object> metadata) {
this.metadata = metadata;
return this;
}

public CreateMessageRequest build() {
return new CreateMessageRequest(messages, modelPreferences, systemPrompt,
includeContext, temperature, maxTokens, stopSequences, metadata);
}
}
}// @formatter:on

@JsonInclude(JsonInclude.Include.NON_ABSENT)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.modelcontextprotocol.spec.McpSchema.CreateMessageRequest;
import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult;
import io.modelcontextprotocol.spec.McpSchema.InitializeResult;
import io.modelcontextprotocol.spec.McpSchema.ModelPreferences;
import io.modelcontextprotocol.spec.McpSchema.Role;
import io.modelcontextprotocol.spec.McpSchema.Root;
import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities;
Expand Down Expand Up @@ -162,13 +163,16 @@ void testCreateMessageSuccess() throws InterruptedException {
McpServerFeatures.AsyncToolSpecification tool = new McpServerFeatures.AsyncToolSpecification(
new McpSchema.Tool("tool1", "tool1 description", emptyJsonSchema), (exchange, request) -> {

var messages = List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
new McpSchema.TextContent("Test message")));
var modelPrefs = new McpSchema.ModelPreferences(List.of(), 1.0, 1.0, 1.0);

var craeteMessageRequest = new McpSchema.CreateMessageRequest(messages, modelPrefs, null,
McpSchema.CreateMessageRequest.ContextInclusionStrategy.NONE, null, 100, List.of(),
Map.of());
var craeteMessageRequest = McpSchema.CreateMessageRequest.builder()
.messages(List.of(new McpSchema.SamplingMessage(McpSchema.Role.USER,
new McpSchema.TextContent("Test message"))))
.modelPreferences(ModelPreferences.builder()
.hints(List.of())
.costPriority(1.0)
.speedPriority(1.0)
.intelligencePriority(1.0)
.build())
.build();

StepVerifier.create(exchange.createMessage(craeteMessageRequest)).consumeNextWith(result -> {
assertThat(result).isNotNull();
Expand Down
22 changes: 16 additions & 6 deletions mcp/src/test/java/io/modelcontextprotocol/spec/McpSchemaTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -524,10 +524,16 @@ void testCreateMessageRequest() throws Exception {
Map<String, Object> metadata = new HashMap<>();
metadata.put("session", "test-session");

McpSchema.CreateMessageRequest request = new McpSchema.CreateMessageRequest(Collections.singletonList(message),
preferences, "You are a helpful assistant",
McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER, 0.7, 1000,
Arrays.asList("STOP", "END"), metadata);
McpSchema.CreateMessageRequest request = McpSchema.CreateMessageRequest.builder()
.messages(Collections.singletonList(message))
.modelPreferences(preferences)
.systemPrompt("You are a helpful assistant")
.includeContext(McpSchema.CreateMessageRequest.ContextInclusionStrategy.THIS_SERVER)
.temperature(0.7)
.maxTokens(1000)
.stopSequences(Arrays.asList("STOP", "END"))
.metadata(metadata)
.build();

String value = mapper.writeValueAsString(request);

Expand All @@ -543,8 +549,12 @@ void testCreateMessageRequest() throws Exception {
void testCreateMessageResult() throws Exception {
McpSchema.TextContent content = new McpSchema.TextContent("Assistant response");

McpSchema.CreateMessageResult result = new McpSchema.CreateMessageResult(McpSchema.Role.ASSISTANT, content,
"gpt-4", McpSchema.CreateMessageResult.StopReason.END_TURN);
McpSchema.CreateMessageResult result = McpSchema.CreateMessageResult.builder()
.role(McpSchema.Role.ASSISTANT)
.content(content)
.model("gpt-4")
.stopReason(McpSchema.CreateMessageResult.StopReason.END_TURN)
.build();

String value = mapper.writeValueAsString(result);

Expand Down