diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java index e4479ff3..22e89389 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java @@ -118,4 +118,17 @@ public static ChatCompletionRequestFunctionCall of(String name) { } } + + /** + * A list of tools the model may call. Currently, only functions are supported as a tool. + */ + List tools; + + /** + * Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function. + */ + @JsonProperty("tool_choice") + String toolChoice; + + } diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java index 912a71f0..d8ed189b 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessage.java @@ -4,6 +4,8 @@ import com.fasterxml.jackson.annotation.JsonProperty; import lombok.*; +import java.util.List; + /** *

Each object has a role (either "system", "user", or "assistant") and content (the content of the message). Conversations can be as short as 1 message or fill many pages.

*

Typically, a conversation is formatted with a system message first, followed by alternating user and assistant messages.

@@ -30,6 +32,10 @@ public class ChatMessage { String content; //name is optional, The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters. String name; + + @JsonProperty("tool_calls") + List toolCalls; + @JsonProperty("function_call") ChatFunctionCall functionCall; diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java index 255641e0..ad7b04f0 100644 --- a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageRole.java @@ -7,7 +7,8 @@ public enum ChatMessageRole { SYSTEM("system"), USER("user"), ASSISTANT("assistant"), - FUNCTION("function"); + FUNCTION("function"), + TOOL("tool"); private final String value; diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageTool.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageTool.java new file mode 100644 index 00000000..9f7466d2 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatMessageTool.java @@ -0,0 +1,24 @@ +package com.theokanning.openai.completion.chat; + +import com.fasterxml.jackson.annotation.JsonProperty; +import lombok.Builder; +import lombok.Data; + +/** + *

Chat Message specialization for tool system + *

+ * + * see here for more info Function Calling + */ + +@Data +public class ChatMessageTool extends ChatMessage { + + @JsonProperty("tool_call_id") + private String toolCallId; + + public ChatMessageTool(String toolCallId, String role, String content, String name) { + super(role,content,name); + this.toolCallId = toolCallId; + } +} diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatTool.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatTool.java new file mode 100644 index 00000000..a8d4aae6 --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatTool.java @@ -0,0 +1,22 @@ +package com.theokanning.openai.completion.chat; + +import lombok.Data; +import lombok.NoArgsConstructor; +import lombok.NonNull; + +@Data +@NoArgsConstructor +public class ChatTool { + + + /** + * The name of the tool being called, only function supported for now. + */ + @NonNull + private String type = "function"; + + + @NonNull + private T function; + +} diff --git a/api/src/main/java/com/theokanning/openai/completion/chat/ChatToolCalls.java b/api/src/main/java/com/theokanning/openai/completion/chat/ChatToolCalls.java new file mode 100644 index 00000000..8809195f --- /dev/null +++ b/api/src/main/java/com/theokanning/openai/completion/chat/ChatToolCalls.java @@ -0,0 +1,28 @@ +package com.theokanning.openai.completion.chat; + +import com.fasterxml.jackson.databind.JsonNode; +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@AllArgsConstructor +@NoArgsConstructor +public class ChatToolCalls { + + /** + * The ID of the tool call + */ + String id; + + /** + * The type of the tool. Currently, only function is supported. + */ + String type; + + /** + * The function that the model called. + */ + ChatFunctionCall function; + +} diff --git a/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java b/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java index 5d143a95..bf6d837f 100644 --- a/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java +++ b/service/src/main/java/com/theokanning/openai/service/FunctionExecutor.java @@ -5,10 +5,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.TextNode; -import com.theokanning.openai.completion.chat.ChatFunction; -import com.theokanning.openai.completion.chat.ChatFunctionCall; -import com.theokanning.openai.completion.chat.ChatMessage; -import com.theokanning.openai.completion.chat.ChatMessageRole; +import com.theokanning.openai.completion.chat.*; import java.util.*; diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index 25f0defb..cc6c97ef 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -83,7 +83,6 @@ void streamChatCompletion() { assertTrue(chunks.size() > 0); assertNotNull(chunks.get(0).getChoices().get(0)); } - @Test void createChatCompletionWithFunctions() { final List functions = Collections.singletonList(ChatFunction.builder() @@ -300,4 +299,76 @@ void streamChatCompletionWithDynamicFunctions() { assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit")); } -} + @Test + void createChatCompletionWithToolFunctions() { + + final List functions = Collections.singletonList(ChatFunction.builder() + .name("get_weather") + .description("Get the current weather in a given location") + .executor(Weather.class, w -> new WeatherResponse(w.location, w.unit, 25, "sunny")) + .build()); + final FunctionExecutor functionExecutor = new FunctionExecutor(functions); + final ChatTool tool = new ChatTool(); + tool.setFunction(functionExecutor.getFunctions().get(0)); + final List messages = new ArrayList<>(); + final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant."); + final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?"); + messages.add(systemMessage); + messages.add(userMessage); + + ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest + .builder() + .model("gpt-3.5-turbo-0613") + .messages(messages) + .tools(List.of(tool)) + .toolChoice("auto") + .n(1) + .maxTokens(100) + .logitBias(new HashMap<>()) + .build(); + + ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0); + assertEquals("tool_calls", choice.getFinishReason()); + + assertEquals("get_weather", choice.getMessage().getToolCalls().get(0).getFunction().getName()); + assertInstanceOf(ObjectNode.class, choice.getMessage().getToolCalls().get(0).getFunction().getArguments()); + + ChatMessage callResponse = functionExecutor.executeAndConvertToMessageHandlingExceptions(choice.getMessage().getToolCalls().get(0).getFunction()); + assertNotEquals("error", callResponse.getName()); + + // this performs an unchecked cast + WeatherResponse functionExecutionResponse = functionExecutor.execute(choice.getMessage().getToolCalls().get(0).getFunction()); + assertInstanceOf(WeatherResponse.class, functionExecutionResponse); + assertEquals(25, functionExecutionResponse.temperature); + + JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(choice.getMessage().getToolCalls().get(0).getFunction()); + assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse); + assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText()); + + //Construct message for tool_calls + ChatMessageTool chatMessageTool = new ChatMessageTool(choice.getMessage().getToolCalls().get(0).getId(), + ChatMessageRole.TOOL.value(), + jsonFunctionExecutionResponse.toString(), + choice.getMessage().getToolCalls().get(0).getFunction().getName()); + + messages.add(choice.getMessage()); + messages.add(chatMessageTool); + + ChatCompletionRequest chatCompletionRequest2 = ChatCompletionRequest + .builder() + .model("gpt-3.5-turbo-0613") + .messages(messages) + .tools(List.of(tool)) + .toolChoice("auto") + .n(1) + .maxTokens(100) + .logitBias(new HashMap<>()) + .build(); + + ChatCompletionChoice choice2 = service.createChatCompletion(chatCompletionRequest2).getChoices().get(0); + assertNotEquals("tool_calls", choice2.getFinishReason()); // could be stop or length, but should not be function_call + assertNull(choice2.getMessage().getFunctionCall()); + assertNotNull(choice2.getMessage().getContent()); + } + +} \ No newline at end of file