Skip to content
This repository was archived by the owner on Jun 6, 2024. It is now read-only.

feat(assistant-function): change the function define and a test for function #415

Merged
merged 1 commit into from
Nov 21, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package com.theokanning.openai.assistants;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.Map;

/**
* @description:
* @author: vacuity
* @create: 2023-11-20 10:09
**/


@Builder
@NoArgsConstructor
@AllArgsConstructor
@Data
public class AssistantFunction {

private String description;

private String name;

private Map<String, Object> parameters;
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.theokanning.openai.assistants;

import com.theokanning.openai.completion.chat.ChatFunction;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@@ -17,5 +16,5 @@ public class Tool {
/**
* Function definition, only used if type is "function"
*/
ChatFunction function;
AssistantFunction function;
}
Original file line number Diff line number Diff line change
@@ -22,5 +22,5 @@
public class SubmitToolOutputsRequest {

@JsonProperty("tool_outputs")
private List<SubmitToolOutputRequestItem> tool_outputs;
private List<SubmitToolOutputRequestItem> toolOutputs;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package com.theokanning.openai.service;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.theokanning.openai.ListSearchParameters;
import com.theokanning.openai.OpenAiResponse;
import com.theokanning.openai.assistants.Assistant;
import com.theokanning.openai.assistants.AssistantFunction;
import com.theokanning.openai.assistants.AssistantRequest;
import com.theokanning.openai.assistants.AssistantToolsEnum;
import com.theokanning.openai.assistants.Tool;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatFunction;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.messages.Message;
import com.theokanning.openai.messages.MessageRequest;
import com.theokanning.openai.runs.RequiredAction;
import com.theokanning.openai.runs.Run;
import com.theokanning.openai.runs.RunCreateRequest;
import com.theokanning.openai.runs.RunStep;
import com.theokanning.openai.runs.SubmitToolOutputRequestItem;
import com.theokanning.openai.runs.SubmitToolOutputs;
import com.theokanning.openai.runs.SubmitToolOutputsRequest;
import com.theokanning.openai.runs.ToolCall;
import com.theokanning.openai.threads.Thread;
import com.theokanning.openai.threads.ThreadRequest;
import com.theokanning.openai.utils.TikTokensUtil;
import org.junit.jupiter.api.Test;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

class AssistantFunctionTest {
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token, Duration.ofMinutes(1));

@Test
void createRetrieveRun() throws JsonProcessingException {

ObjectMapper mapper = new ObjectMapper();
mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL);
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class);
mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class);
mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class);

String funcDef = "{\n" +
" \"type\": \"object\",\n" +
" \"properties\": {\n" +
" \"location\": {\n" +
" \"type\": \"string\",\n" +
" \"description\": \"The city and state, e.g. San Francisco, CA\"\n" +
" },\n" +
" \"unit\": {\n" +
" \"type\": \"string\",\n" +
" \"enum\": [\"celsius\", \"fahrenheit\"]\n" +
" }\n" +
" },\n" +
" \"required\": [\"location\"]\n" +
"}";
Map<String, Object> funcParameters = mapper.readValue(funcDef, new TypeReference<Map<String, Object>>() {});
AssistantFunction function = AssistantFunction.builder()
.name("weather_reporter")
.description("Get the current weather of a location")
.parameters(funcParameters)
.build();

List<Tool> toolList = new ArrayList<>();
Tool funcTool = new Tool(AssistantToolsEnum.FUNCTION, function);
toolList.add(funcTool);


AssistantRequest assistantRequest = AssistantRequest.builder()
.model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName())
.name("MATH_TUTOR")
.instructions("You are a personal Math Tutor.")
.tools(toolList)
.build();
Assistant assistant = service.createAssistant(assistantRequest);

ThreadRequest threadRequest = ThreadRequest.builder()
.build();
Thread thread = service.createThread(threadRequest);

MessageRequest messageRequest = MessageRequest.builder()
.content("What's the weather of Xiamen?")
.build();

Message message = service.createMessage(thread.getId(), messageRequest);

RunCreateRequest runCreateRequest = RunCreateRequest.builder()
.assistantId(assistant.getId())
.build();

Run run = service.createRun(thread.getId(), runCreateRequest);
assertNotNull(run);

Run retrievedRun = service.retrieveRun(thread.getId(), run.getId());
while (!(retrievedRun.getStatus().equals("completed"))
&& !(retrievedRun.getStatus().equals("failed"))
&& !(retrievedRun.getStatus().equals("requires_action"))){
retrievedRun = service.retrieveRun(thread.getId(), run.getId());
}
if (retrievedRun.getStatus().equals("requires_action")) {
RequiredAction requiredAction = retrievedRun.getRequiredAction();
System.out.println("requiredAction");
System.out.println(mapper.writeValueAsString(requiredAction));
List<ToolCall> toolCalls = requiredAction.getSubmitToolOutputs().getToolCalls();
ToolCall toolCall = toolCalls.get(0);
String toolCallId = toolCall.getId();

SubmitToolOutputRequestItem toolOutputRequestItem = SubmitToolOutputRequestItem.builder()
.toolCallId(toolCallId)
.output("sunny")
.build();
List<SubmitToolOutputRequestItem> toolOutputRequestItems = new ArrayList<>();
toolOutputRequestItems.add(toolOutputRequestItem);
SubmitToolOutputsRequest submitToolOutputsRequest = SubmitToolOutputsRequest.builder()
.toolOutputs(toolOutputRequestItems)
.build();
retrievedRun = service.submitToolOutputs(retrievedRun.getThreadId(), retrievedRun.getId(), submitToolOutputsRequest);

while (!(retrievedRun.getStatus().equals("completed"))
&& !(retrievedRun.getStatus().equals("failed"))
&& !(retrievedRun.getStatus().equals("requires_action"))){
retrievedRun = service.retrieveRun(thread.getId(), run.getId());
}

OpenAiResponse<Message> response = service.listMessages(thread.getId());

List<Message> messages = response.getData();

System.out.println(mapper.writeValueAsString(messages));

}
}
}