Skip to content

Commit e847d66

Browse files
committed
Added response_format capabilities and integration test covering it (TheoKanning#388)
1 parent 96a169a commit e847d66

File tree

2 files changed

+73
-1
lines changed

2 files changed

+73
-1
lines changed

api/src/main/java/com/theokanning/openai/completion/chat/ChatCompletionRequest.java

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,19 @@ public class ChatCompletionRequest {
3434
@JsonProperty("frequency_penalty")
3535
Double frequencyPenalty;
3636

37+
/**
38+
* <p>An object specifying the format that the model must output.</p>
39+
*
40+
* <p>Setting to { "type": "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON.</p>
41+
*
42+
* <p><b>Important:</b> when using JSON mode, you must also instruct the model to produce JSON yourself via a system or user message.
43+
* Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting
44+
* in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if
45+
* finish_reason="length", which indicates the generation exceeded max_tokens or the conversation exceeded the max context length.</p>
46+
*/
47+
@JsonProperty("response_format")
48+
ResponseFormat responseFormat;
49+
3750
/**
3851
* Accepts a json object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100
3952
* to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will
@@ -117,4 +130,18 @@ public static ChatCompletionRequestFunctionCall of(String name) {
117130
}
118131

119132
}
133+
134+
@Data
135+
@Builder
136+
@AllArgsConstructor
137+
@NoArgsConstructor
138+
public static class ResponseFormat {
139+
String type;
140+
141+
public static ResponseFormat of(String type) {
142+
return new ResponseFormat(type);
143+
}
144+
145+
}
146+
120147
}

service/src/test/java/com/theokanning/openai/service/ChatCompletionTest.java

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22

33
import com.fasterxml.jackson.annotation.JsonProperty;
44
import com.fasterxml.jackson.annotation.JsonPropertyDescription;
5+
import com.fasterxml.jackson.core.JsonParser;
56
import com.fasterxml.jackson.databind.JsonNode;
7+
import com.fasterxml.jackson.databind.ObjectMapper;
68
import com.fasterxml.jackson.databind.node.ObjectNode;
79
import com.theokanning.openai.completion.chat.*;
810
import org.junit.jupiter.api.Test;
911

12+
import java.io.IOException;
1013
import java.util.*;
1114

1215
import static org.junit.jupiter.api.Assertions.*;
@@ -23,7 +26,7 @@ static class Weather {
2326
}
2427

2528
enum WeatherUnit {
26-
CELSIUS, FAHRENHEIT;
29+
CELSIUS, FAHRENHEIT
2730
}
2831

2932
static class WeatherResponse {
@@ -300,4 +303,46 @@ void streamChatCompletionWithDynamicFunctions() {
300303
assertNotNull(accumulatedMessage.getFunctionCall().getArguments().get("unit"));
301304
}
302305

306+
@Test
307+
void streamChatCompletionWithJsonResponseFormat() {
308+
final List<ChatMessage> messages = new ArrayList<>();
309+
310+
// The system message is deliberately vague in order to not give too much of a direction of how response should look like.
311+
// The main gist there is that chat competition should always contain JSON content.
312+
final ChatMessage systemMessage = new ChatMessage(
313+
ChatMessageRole.SYSTEM.value(),
314+
"You are a dog and will speak as such - but please do it in JSON."
315+
);
316+
317+
messages.add(systemMessage);
318+
319+
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
320+
.builder()
321+
.model("gpt-4-1106-preview")
322+
.messages(messages)
323+
.n(1)
324+
.maxTokens(256)
325+
.responseFormat(ChatCompletionRequest.ResponseFormat.of("json_object"))
326+
.build();
327+
328+
ChatCompletionResult chatCompletion = service.createChatCompletion(chatCompletionRequest);
329+
330+
ChatCompletionChoice chatCompletionChoice = chatCompletion.getChoices().get(0);
331+
String expectedJsonContent = chatCompletionChoice.getMessage().getContent();
332+
333+
assertTrue(isValidJSON(expectedJsonContent), "Invalid JSON response:\n\n" + expectedJsonContent);
334+
}
335+
336+
private boolean isValidJSON(String json) {
337+
try (final JsonParser parser = new ObjectMapper().createParser(json)) {
338+
while (parser.nextToken() != null) {
339+
// Just try to read all tokens in order to verify whether this is valid json.
340+
}
341+
return true;
342+
} catch (IOException ioe) {
343+
ioe.printStackTrace();
344+
return false;
345+
}
346+
}
347+
303348
}

0 commit comments

Comments
 (0)