Skip to content
This repository was archived by the owner on Jun 6, 2024. It is now read-only.

[Chat completion API] Support tools and tool_choice #437

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,17 @@ public static ChatCompletionRequestFunctionCall of(String name) {
}

}

/**
* A list of tools the model may call. Currently, only functions are supported as a tool.
*/
List<ChatTool> tools;

/**
* Controls which (if any) function is called by the model. none means the model will not call a function and instead generates a message. auto means the model can pick between generating a message or calling a function.
*/
@JsonProperty("tool_choice")
String toolChoice;


}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.*;

import java.util.List;

/**
* <p>Each object has a role (either "system", "user", or "assistant") and content (the content of the message). Conversations can be as short as 1 message or fill many pages.</p>
* <p>Typically, a conversation is formatted with a system message first, followed by alternating user and assistant messages.</p>
Expand All @@ -30,6 +32,10 @@ public class ChatMessage {
String content;
//name is optional, The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores, with a maximum length of 64 characters.
String name;

@JsonProperty("tool_calls")
List<ChatToolCalls> toolCalls;

@JsonProperty("function_call")
ChatFunctionCall functionCall;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ public enum ChatMessageRole {
SYSTEM("system"),
USER("user"),
ASSISTANT("assistant"),
FUNCTION("function");
FUNCTION("function"),
TOOL("tool");

private final String value;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.theokanning.openai.completion.chat;

import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Builder;
import lombok.Data;

/**
* <p>Chat Message specialization for tool system
* </p>
*
* see here for more info <a href="https://platform.openai.com/docs/guides/function-calling">Function Calling</a>
*/

@Data
public class ChatMessageTool extends ChatMessage {

@JsonProperty("tool_call_id")
private String toolCallId;

public ChatMessageTool(String toolCallId, String role, String content, String name) {
super(role,content,name);
this.toolCallId = toolCallId;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package com.theokanning.openai.completion.chat;

import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.NonNull;

@Data
@NoArgsConstructor
public class ChatTool<T> {


/**
* The name of the tool being called, only function supported for now.
*/
@NonNull
private String type = "function";


@NonNull
private T function;

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.theokanning.openai.completion.chat;

import com.fasterxml.jackson.databind.JsonNode;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class ChatToolCalls {

/**
* The ID of the tool call
*/
String id;

/**
* The type of the tool. Currently, only function is supported.
*/
String type;

/**
* The function that the model called.
*/
ChatFunctionCall function;

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.fasterxml.jackson.databind.node.TextNode;
import com.theokanning.openai.completion.chat.ChatFunction;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.completion.chat.*;

import java.util.*;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ void streamChatCompletion() {
assertTrue(chunks.size() > 0);
assertNotNull(chunks.get(0).getChoices().get(0));
}

@Test
void createChatCompletionWithFunctions() {
final List<ChatFunction> functions = Collections.singletonList(ChatFunction.builder()
Expand Down Expand Up @@ -300,4 +299,76 @@ void streamChatCompletionWithDynamicFunctions() {
assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit"));
}

}
@Test
void createChatCompletionWithToolFunctions() {

final List<ChatFunction> functions = Collections.singletonList(ChatFunction.builder()
.name("get_weather")
.description("Get the current weather in a given location")
.executor(Weather.class, w -> new WeatherResponse(w.location, w.unit, 25, "sunny"))
.build());
final FunctionExecutor functionExecutor = new FunctionExecutor(functions);
final ChatTool tool = new ChatTool();
tool.setFunction(functionExecutor.getFunctions().get(0));
final List<ChatMessage> messages = new ArrayList<>();
final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a helpful assistant.");
final ChatMessage userMessage = new ChatMessage(ChatMessageRole.USER.value(), "What is the weather in Monterrey, Nuevo León?");
messages.add(systemMessage);
messages.add(userMessage);

ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.messages(messages)
.tools(List.of(tool))
.toolChoice("auto")
.n(1)
.maxTokens(100)
.logitBias(new HashMap<>())
.build();

ChatCompletionChoice choice = service.createChatCompletion(chatCompletionRequest).getChoices().get(0);
assertEquals("tool_calls", choice.getFinishReason());

assertEquals("get_weather", choice.getMessage().getToolCalls().get(0).getFunction().getName());
assertInstanceOf(ObjectNode.class, choice.getMessage().getToolCalls().get(0).getFunction().getArguments());

ChatMessage callResponse = functionExecutor.executeAndConvertToMessageHandlingExceptions(choice.getMessage().getToolCalls().get(0).getFunction());
assertNotEquals("error", callResponse.getName());

// this performs an unchecked cast
WeatherResponse functionExecutionResponse = functionExecutor.execute(choice.getMessage().getToolCalls().get(0).getFunction());
assertInstanceOf(WeatherResponse.class, functionExecutionResponse);
assertEquals(25, functionExecutionResponse.temperature);

JsonNode jsonFunctionExecutionResponse = functionExecutor.executeAndConvertToJson(choice.getMessage().getToolCalls().get(0).getFunction());
assertInstanceOf(ObjectNode.class, jsonFunctionExecutionResponse);
assertEquals("25", jsonFunctionExecutionResponse.get("temperature").asText());

//Construct message for tool_calls
ChatMessageTool chatMessageTool = new ChatMessageTool(choice.getMessage().getToolCalls().get(0).getId(),
ChatMessageRole.TOOL.value(),
jsonFunctionExecutionResponse.toString(),
choice.getMessage().getToolCalls().get(0).getFunction().getName());

messages.add(choice.getMessage());
messages.add(chatMessageTool);

ChatCompletionRequest chatCompletionRequest2 = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.messages(messages)
.tools(List.of(tool))
.toolChoice("auto")
.n(1)
.maxTokens(100)
.logitBias(new HashMap<>())
.build();

ChatCompletionChoice choice2 = service.createChatCompletion(chatCompletionRequest2).getChoices().get(0);
assertNotEquals("tool_calls", choice2.getFinishReason()); // could be stop or length, but should not be function_call
assertNull(choice2.getMessage().getFunctionCall());
assertNotNull(choice2.getMessage().getContent());
}

}