From c04824df741325ab8ec6b9f4b3dc17d31164904d Mon Sep 17 00:00:00 2001 From: Leonardo Emili Date: Tue, 3 Oct 2023 22:30:35 +0000 Subject: [PATCH 1/2] Unwrap SocketTimeoutException and update tests accordingly --- .../main/java/example/OpenAiApiFunctionsExample.java | 3 ++- .../theokanning/openai/service/OpenAiService.java | 12 ++++++++++-- .../openai/service/ChatCompletionTest.java | 5 +++-- 3 files changed, 15 insertions(+), 5 deletions(-) diff --git a/example/src/main/java/example/OpenAiApiFunctionsExample.java b/example/src/main/java/example/OpenAiApiFunctionsExample.java index 954b9104..50c29160 100644 --- a/example/src/main/java/example/OpenAiApiFunctionsExample.java +++ b/example/src/main/java/example/OpenAiApiFunctionsExample.java @@ -7,6 +7,7 @@ import com.theokanning.openai.service.FunctionExecutor; import com.theokanning.openai.service.OpenAiService; +import java.net.SocketTimeoutException; import java.util.*; class OpenAiApiFunctionsExample { @@ -38,7 +39,7 @@ public WeatherResponse(String location, WeatherUnit unit, int temperature, Strin } } - public static void main(String... args) { + public static void main(String... args) throws SocketTimeoutException { String token = System.getenv("OPENAI_TOKEN"); OpenAiService service = new OpenAiService(token); diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index 7114531b..28df3302 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -49,6 +49,7 @@ import javax.validation.constraints.NotNull; import java.io.IOException; +import java.net.SocketTimeoutException; import java.time.Duration; import java.time.LocalDate; import java.util.List; @@ -134,8 +135,15 @@ public Flowable streamCompletion(CompletionRequest request) { return stream(api.createCompletionStream(request), CompletionChunk.class); } - public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) { - return execute(api.createChatCompletion(request)); + public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) throws SocketTimeoutException{ + try { + return execute(api.createChatCompletion(request)); + } catch (RuntimeException e) { + if (e.getCause() != null && e.getCause() instanceof SocketTimeoutException && e.getCause().getMessage() == "hello world") + throw new SocketTimeoutException(e.getCause().getMessage()); + else + throw e; + } } public Flowable streamChatCompletion(ChatCompletionRequest request) { diff --git a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java index 3d26bf03..eab0af41 100644 --- a/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java +++ b/service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java @@ -7,6 +7,7 @@ import com.theokanning.openai.completion.chat.*; import org.junit.jupiter.api.Test; +import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -47,7 +48,7 @@ public WeatherResponse(String location, WeatherUnit unit, int temperature, Strin OpenAiService service = new OpenAiService(token); @Test - void createChatCompletion() { + void createChatCompletion() throws SocketTimeoutException { final List messages = new ArrayList<>(); final ChatMessage systemMessage = new ChatMessage(ChatMessageRole.SYSTEM.value(), "You are a dog and will speak as such."); messages.add(systemMessage); @@ -88,7 +89,7 @@ void streamChatCompletion() { } @Test - void createChatCompletionWithFunctions() { + void createChatCompletionWithFunctions() throws SocketTimeoutException { final List functions = Collections.singletonList(ChatFunction.builder() .name("get_weather") .description("Get the current weather in a given location") From 7b63caadf6d6acf225aabce36f1ef49c843d6f12 Mon Sep 17 00:00:00 2001 From: Leonardo Emili Date: Tue, 3 Oct 2023 22:37:57 +0000 Subject: [PATCH 2/2] Remove debug flag --- .../main/java/com/theokanning/openai/service/OpenAiService.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java index 28df3302..75f4d3c4 100644 --- a/service/src/main/java/com/theokanning/openai/service/OpenAiService.java +++ b/service/src/main/java/com/theokanning/openai/service/OpenAiService.java @@ -139,7 +139,7 @@ public ChatCompletionResult createChatCompletion(ChatCompletionRequest request) try { return execute(api.createChatCompletion(request)); } catch (RuntimeException e) { - if (e.getCause() != null && e.getCause() instanceof SocketTimeoutException && e.getCause().getMessage() == "hello world") + if (e.getCause() != null && e.getCause() instanceof SocketTimeoutException) throw new SocketTimeoutException(e.getCause().getMessage()); else throw e;