From 946ed151572c116b7892fbe9d659d10203147475 Mon Sep 17 00:00:00 2001 From: gdorsi Date: Wed, 13 Dec 2023 14:31:52 +0100 Subject: [PATCH 1/2] [Chat completion API] Support tools and tool_choice --- .../chat/ChatCompletionRequest.java | 13 ++++ .../openai/completion/chat/ChatMessage.java | 6 ++ .../completion/chat/ChatMessageRole.java | 3 +- .../completion/chat/ChatMessageTool.java | 24 ++++++ .../openai/completion/chat/ChatTool.java | 22 ++++++ .../openai/completion/chat/ChatToolCalls.java | 28 +++++++ .../openai/service/FunctionExecutor.java | 5 +- .../openai/service/ChatCompletionTest.java | 78 ++++++++++++++++++- 8 files changed, 172 insertions(+), 7 deletions(-) create mode 100644 api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageTool.java create mode 100644 api/src/main/java/com/theokanning/openai/completion/chat/ChatTool.java create mode 100644 api/src/main/java/com/theokanning/openai/completion/chat/ChatToolCalls.java diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java index e4479ff3..22e89389 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java @@ -118,4 +118,17 @@ public static ChatCompletionRequestFunctionCall of(String name) { } } + + /** + * A list of tools the model may call. Currently, only functions are supported as a tool. + */ + List tools; + + /** + * Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. + */ + @JsonProperty("tool_choice") + String toolChoice; + + } diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java index 912a71f0..d8ed189b 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java @@ -4,6 +4,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import lombok.*; +import java.util.List; + /** *

Each object has a role (either "system", "user", or "assistant") and content (the content of the message). Conversations can be as short as 1 message or fill many pages.

*

Typically, a conversation is formatted with a system message first, followed by alternating user and assistant messages.

@@ -30,6 +32,10 @@ public class ChatMessage { String content; //name is optional, The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters. String name; + + @JsonProperty("tool_calls") + List toolCalls; + @JsonProperty("function_call") ChatFunctionCall functionCall; diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java index 255641e0..ad7b04f0 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java @@ -7,7 +7,8 @@ public enum ChatMessageRole { SYSTEM("system"), USER("user"), ASSISTANT("assistant"), - FUNCTION("function"); + FUNCTION("function"), + TOOL("tool"); private final String value; diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageTool.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageTool.java new file mode 100644 index 00000000..9f7466d2 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageTool.java @@ -0,0 +1,24 @@ +package com.theokanning.openai.completion.chat; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Builder; +import lombok.Data; + +/** + *

Chat Message specialization for tool system + *

+ * + * see here for more info Function Calling + */ + +@Data +public class ChatMessageTool extends ChatMessage { + + @JsonProperty("tool_call_id") + private String toolCallId; + + public ChatMessageTool(String toolCallId, String role, String content, String name) { + super(role,content,name); + this.toolCallId = toolCallId; + } +} diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatTool.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatTool.java new file mode 100644 index 00000000..a8d4aae6 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatTool.java @@ -0,0 +1,22 @@ +package com.theokanning.openai.completion.chat; + +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; + +@Data +@NoArgsConstructor +public class ChatTool { + + + /** + * The name of the tool being called, only function supported for now. + */ + @NonNull + private String type = "function"; + + + @NonNull + private T function; + +} diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatToolCalls.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatToolCalls.java new file mode 100644 index 00000000..8809195f --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatToolCalls.java @@ -0,0 +1,28 @@ +package com.theokanning.openai.completion.chat; + +import com.fasterxml.jackson.databind.JsonNode; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class ChatToolCalls { + + /** + * The ID of the tool call + */ + String id; + + /** + * The type of the tool. Currently, only function is supported. + */ + String type; + + /** + * The function that the model called. + */ + ChatFunctionCall function; + +} diff --git a/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java b/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java index 5d143a95..bf6d837f 100644 --- a/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java +++ b/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java @@ -5,10 +5,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.TextNode; -import com.theokanning.openai.completion.chat.ChatFunction; -import com.theokanning.openai.completion.chat.ChatFunctionCall; -import com.theokanning.openai.completion.chat.ChatMessage; -import com.theokanning.openai.completion.chat.ChatMessageRole; +import com.theokanning.openai.completion.chat.*; import java.util.*; diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index 25f0defb..3b177821 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -83,7 +83,6 @@ void streamChatCompletion() { assertTrue(chunks.size() > 0); assertNotNull(chunks.get(0).getChoices().get(0)); } - @Test void createChatCompletionWithFunctions() { final List functions = Collections.singletonList(ChatFunction.builder() @@ -300,4 +299,79 @@ void streamChatCompletionWithDynamicFunctions() { assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit")); } -} + @Test + void createChatCompletionWithToolFunctions() { + final List functions = Collections.singletonList(ChatFunction.builder() + .name("get_weather") + .description("Get the current weather in a given location") + .executor(Weather.class, w -> new WeatherResponse(w.location, w.unit, 25, "sunny")) + .build()); + + final FunctionExecutor functionExecutor = new FunctionExecutor(functions); + final ChatTool tool = new ChatTool(); + //tool.setType("function"); + tool.setFunction(functionExecutor.getFunctions().get(0)); + + final List messages = new ArrayList<>(); + final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant."); + final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?"); + messages.add(systemMessage); + messages.add(userMessage); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest + .builder() + .model("gpt-3.5-turbo-0613") + .messages(messages) + .tools(List.of(tool)) + .toolChoice("auto") + + .n(1) + .maxTokens(100) + .logitBias(new HashMap<>()) + .build(); + + ChatCompletionResult result = service.createChatCompletion(chatCompletionRequest); + ChatCompletionChoice choice = result.getChoices().get(0); + assertEquals("tool_calls", choice.getFinishReason()); + + assertEquals("get_weather", choice.getMessage().getToolCalls().get(0).getFunction().getName()); + assertInstanceOf(ObjectNode.class, choice.getMessage().getToolCalls().get(0).getFunction().getArguments()); + + ChatMessage callResponse = functionExecutor.executeAndConvertToMessageHandlingExceptions(choice.getMessage().getToolCalls().get(0).getFunction()); + assertNotEquals("error", callResponse.getName()); + + // this performs an unchecked cast + WeatherResponse functionExecutionResponse = functionExecutor.execute(choice.getMessage().getToolCalls().get(0).getFunction()); + assertInstanceOf(WeatherResponse.class, functionExecutionResponse); + assertEquals(25, functionExecutionResponse.temperature); + + JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(choice.getMessage().getToolCalls().get(0).getFunction()); + assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse); + assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText()); + + //Construct message for tool_calls + ChatMessageTool chatMessageTool = new ChatMessageTool(choice.getMessage().getToolCalls().get(0).getId(), + ChatMessageRole.TOOL.value(),jsonFunctionExecutionResponse.toString(), + choice.getMessage().getToolCalls().get(0).getFunction().getName()); + + messages.add(choice.getMessage()); + messages.add(chatMessageTool); + + ChatCompletionRequest chatCompletionRequest2 = ChatCompletionRequest + .builder() + .model("gpt-3.5-turbo-0613") + .messages(messages) + .tools(List.of(tool)) + .toolChoice("auto") + .n(1) + .maxTokens(100) + .logitBias(new HashMap<>()) + .build(); + + ChatCompletionChoice choice2 = service.createChatCompletion(chatCompletionRequest2).getChoices().get(0); + assertNotEquals("tool_calls", choice2.getFinishReason()); // could be stop or length, but should not be function_call + assertNull(choice2.getMessage().getFunctionCall()); + assertNotNull(choice2.getMessage().getContent()); + } + +} \ No newline at end of file From 00731d4580980fe1ffc3d6be0c2a4632f42fc349 Mon Sep 17 00:00:00 2001 From: gdorsi Date: Thu, 14 Dec 2023 08:23:41 +0100 Subject: [PATCH 2/2] Simplified unit test createChatCompletionWithToolFunctions --- .../openai/service/ChatCompletionTest.java | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index 3b177821..cc6c97ef 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -301,17 +301,15 @@ void streamChatCompletionWithDynamicFunctions() { @Test void createChatCompletionWithToolFunctions() { + final List functions = Collections.singletonList(ChatFunction.builder() .name("get_weather") .description("Get the current weather in a given location") .executor(Weather.class, w -> new WeatherResponse(w.location, w.unit, 25, "sunny")) .build()); - final FunctionExecutor functionExecutor = new FunctionExecutor(functions); final ChatTool tool = new ChatTool(); - //tool.setType("function"); tool.setFunction(functionExecutor.getFunctions().get(0)); - final List messages = new ArrayList<>(); final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant."); final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?"); @@ -324,14 +322,12 @@ void createChatCompletionWithToolFunctions() { .messages(messages) .tools(List.of(tool)) .toolChoice("auto") - .n(1) .maxTokens(100) .logitBias(new HashMap<>()) .build(); - ChatCompletionResult result = service.createChatCompletion(chatCompletionRequest); - ChatCompletionChoice choice = result.getChoices().get(0); + ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0); assertEquals("tool_calls", choice.getFinishReason()); assertEquals("get_weather", choice.getMessage().getToolCalls().get(0).getFunction().getName()); @@ -351,8 +347,9 @@ void createChatCompletionWithToolFunctions() { //Construct message for tool_calls ChatMessageTool chatMessageTool = new ChatMessageTool(choice.getMessage().getToolCalls().get(0).getId(), - ChatMessageRole.TOOL.value(),jsonFunctionExecutionResponse.toString(), - choice.getMessage().getToolCalls().get(0).getFunction().getName()); + ChatMessageRole.TOOL.value(), + jsonFunctionExecutionResponse.toString(), + choice.getMessage().getToolCalls().get(0).getFunction().getName()); messages.add(choice.getMessage()); messages.add(chatMessageTool);