Skip to content

GH-2737 Returning logprobs in generation metadata when requested #2750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,9 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata,
generationMetadataBuilder.metadata("audioId", audioOutput.id());
generationMetadataBuilder.metadata("audioExpiresAt", audioOutput.expiresAt());
}
else if (Boolean.TRUE.equals(request.logprobs())) {
generationMetadataBuilder.metadata("logprobs", choice.logprobs());
}

var assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media);
return new Generation(assistantMessage, generationMetadataBuilder.build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -73,7 +74,7 @@ void resetMockServer() {
@Test
void aiResponseContainsAiMetadata() {

prepareMock();
prepareMock(false);

Prompt prompt = new Prompt("Reach for the sky.");

Expand Down Expand Up @@ -118,13 +119,32 @@ void aiResponseContainsAiMetadata() {

response.getResults().forEach(generation -> {
ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata();
var logprobs = chatGenerationMetadata.get("logprobs");
assertThat(logprobs).isNull();
assertThat(chatGenerationMetadata).isNotNull();
assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP");
assertThat(chatGenerationMetadata.getContentFilters()).isEmpty();
});
}

private void prepareMock() {
// Verifies that, when logprobs are requested through the chat options, the
// generation metadata exposes the parsed payload under the "logprobs" key.
@Test
void aiResponseContainsAiLogprobsMetadata() {

	prepareMock(true);

	var chatOptions = new OpenAiChatOptions.Builder().logprobs(true).build();
	var prompt = new Prompt("Reach for the sky.", chatOptions);

	var chatResponse = this.openAiChatClient.call(prompt);

	assertThat(chatResponse).isNotNull();
	var result = chatResponse.getResult();
	assertThat(result).isNotNull();
	assertThat(result.getMetadata()).isNotNull();

	Object logprobsEntry = result.getMetadata().get("logprobs");
	assertThat(logprobsEntry).isNotNull().isInstanceOf(OpenAiApi.LogProbs.class);
}

private void prepareMock(boolean includeLogprobs) {

HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000");
Expand All @@ -137,34 +157,58 @@ private void prepareMock() {
this.server.expect(requestTo(StringContains.containsString("/v1/chat/completions")))
.andExpect(method(HttpMethod.POST))
.andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY))
.andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders));
.andRespond(withSuccess(getJson(includeLogprobs), MediaType.APPLICATION_JSON).headers(httpHeaders));

}

private String getJson() {
/**
 * Returns the mocked chat-completion response body as a template. The {@code %s}
 * placeholder inside the first choice is filled by {@code getJson(boolean)} with
 * either a trailing-comma "logprobs" object or an empty string, so the result is
 * valid JSON in both cases.
 */
private String getBaseJson() {
	return """
				{
					"id": "chatcmpl-123",
					"object": "chat.completion",
					"created": 1677652288,
					"model": "gpt-3.5-turbo-0613",
					"choices": [{
						"index": 0,
						"message": {
							"role": "assistant",
							"content": "I surrender!"
						},
						%s
						"finish_reason": "stop"
					}],
					"usage": {
						"prompt_tokens": 9,
						"completion_tokens": 12,
						"total_tokens": 21
					}
				}
			""";
}

/**
 * Builds the mocked chat-completion response body.
 *
 * @param includeLogprobs whether to embed a "logprobs" object in the first choice;
 * when {@code false} the template placeholder is replaced with an empty string
 * @return the JSON payload served by the mock server
 */
private String getJson(boolean includeLogprobs) {
	// Single exit point: pick the placeholder content first, format once.
	// The snippet ends with a comma so it splices cleanly before "finish_reason".
	String logprobs = includeLogprobs ? """
			"logprobs" : {
				"content" : [ {
					"token" : "I",
					"logprob" : -0.029507114,
					"bytes" : [ 73 ],
					"top_logprobs" : [ ]
				}, {
					"token" : " surrender!",
					"logprob" : -0.061970375,
					"bytes" : [ 32, 115, 117, 114, 114, 101, 110, 100, 101, 114, 33 ],
					"top_logprobs" : [ ]
				} ]
			},
			""" : "";
	return String.format(getBaseJson(), logprobs);
}

@SpringBootConfiguration
static class Config {

Expand Down