diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index 7c036c67d3d..9a3884f468e 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -434,6 +434,9 @@ private Generation buildGeneration(Choice choice, Map<String, Object> metadata,
 			generationMetadataBuilder.metadata("audioId", audioOutput.id());
 			generationMetadataBuilder.metadata("audioExpiresAt", audioOutput.expiresAt());
 		}
+		else if (Boolean.TRUE.equals(request.logprobs())) {
+			generationMetadataBuilder.metadata("logprobs", choice.logprobs());
+		}
 
 		var assistantMessage = new AssistantMessage(textContent, metadata, toolCalls, media);
 		return new Generation(assistantMessage, generationMetadataBuilder.build());
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java
index f606e16ad4d..1e9815513c3 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatModelWithChatResponseMetadataTests.java
@@ -30,6 +30,7 @@
 import org.springframework.ai.chat.model.ChatResponse;
 import org.springframework.ai.chat.prompt.Prompt;
 import org.springframework.ai.openai.OpenAiChatModel;
+import org.springframework.ai.openai.OpenAiChatOptions;
 import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.metadata.support.OpenAiApiResponseHeaders;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -73,7 +74,7 @@ void resetMockServer() {
 	@Test
 	void aiResponseContainsAiMetadata() {
 
-		prepareMock();
+		prepareMock(false);
 
 		Prompt prompt = new Prompt("Reach for the sky.");
 
@@ -118,13 +119,32 @@ void aiResponseContainsAiMetadata() {
 
 		response.getResults().forEach(generation -> {
 			ChatGenerationMetadata chatGenerationMetadata = generation.getMetadata();
+			var logprobs = chatGenerationMetadata.get("logprobs");
+			assertThat(logprobs).isNull();
 			assertThat(chatGenerationMetadata).isNotNull();
 			assertThat(chatGenerationMetadata.getFinishReason()).isEqualTo("STOP");
 			assertThat(chatGenerationMetadata.getContentFilters()).isEmpty();
 		});
 	}
 
-	private void prepareMock() {
+	@Test
+	void aiResponseContainsAiLogprobsMetadata() {
+
+		prepareMock(true);
+
+		Prompt prompt = new Prompt("Reach for the sky.", new OpenAiChatOptions.Builder().logprobs(true).build());
+
+		ChatResponse response = this.openAiChatClient.call(prompt);
+
+		assertThat(response).isNotNull();
+		assertThat(response.getResult()).isNotNull();
+		assertThat(response.getResult().getMetadata()).isNotNull();
+
+		var logprobs = response.getResult().getMetadata().get("logprobs");
+		assertThat(logprobs).isNotNull().isInstanceOf(OpenAiApi.LogProbs.class);
+	}
+
+	private void prepareMock(boolean includeLogprobs) {
 
 		HttpHeaders httpHeaders = new HttpHeaders();
 		httpHeaders.set(OpenAiApiResponseHeaders.REQUESTS_LIMIT_HEADER.getName(), "4000");
@@ -137,34 +157,58 @@ private void prepareMock() {
 		this.server.expect(requestTo(StringContains.containsString("/v1/chat/completions")))
 			.andExpect(method(HttpMethod.POST))
 			.andExpect(header(HttpHeaders.AUTHORIZATION, "Bearer " + TEST_API_KEY))
-			.andRespond(withSuccess(getJson(), MediaType.APPLICATION_JSON).headers(httpHeaders));
+			.andRespond(withSuccess(getJson(includeLogprobs), MediaType.APPLICATION_JSON).headers(httpHeaders));
 	}
 
-	private String getJson() {
+	private String getBaseJson() {
 		return """
-				{
-					"id": "chatcmpl-123",
-					"object": "chat.completion",
-					"created": 1677652288,
-					"model": "gpt-3.5-turbo-0613",
-					"choices": [{
-						"index": 0,
-						"message": {
-							"role": "assistant",
-							"content": "I surrender!"
-						},
-						"finish_reason": "stop"
-					}],
-					"usage": {
-						"prompt_tokens": 9,
-						"completion_tokens": 12,
-						"total_tokens": 21
-					}
-				}
+				{
+					"id": "chatcmpl-123",
+					"object": "chat.completion",
+					"created": 1677652288,
+					"model": "gpt-3.5-turbo-0613",
+					"choices": [{
+						"index": 0,
+						"message": {
+							"role": "assistant",
+							"content": "I surrender!"
+						},
+						%s
+						"finish_reason": "stop"
+					}],
+					"usage": {
+						"prompt_tokens": 9,
+						"completion_tokens": 12,
+						"total_tokens": 21
+					}
+				}
 				""";
 	}
 
+	private String getJson(boolean includeLogprobs) {
+		if (includeLogprobs) {
+			String logprobs = """
+					"logprobs" : {
+						"content" : [ {
+							"token" : "I",
+							"logprob" : -0.029507114,
+							"bytes" : [ 73 ],
+							"top_logprobs" : [ ]
+						}, {
+							"token" : " surrender!",
+							"logprob" : -0.061970375,
+							"bytes" : [ 32, 115, 117, 114, 114, 101, 110, 100, 101, 114, 33 ],
+							"top_logprobs" : [ ]
+						} ]
+					},
+					""";
+			return String.format(getBaseJson(), logprobs);
+		}
+
+		return String.format(getBaseJson(), "");
+	}
+
 	@SpringBootConfiguration
 	static class Config {