Skip to content

Commit 5e0b757

Browse files
authored
.Net: Added logprobs property to OpenAIPromptExecutionSettings (#6300)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Fixes: #6277 https://platform.openai.com/docs/api-reference/chat/create#chat-create-logprobs ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [x] I didn't break anyone 😄
1 parent 3e19785 commit 5e0b757

File tree

File tree: 6 files changed (+100 −8 lines changed)

dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -1050,7 +1050,7 @@ private static CompletionsOptions CreateCompletionsOptions(string text, OpenAIPr
10501050
Echo = false,
10511051
ChoicesPerPrompt = executionSettings.ResultsPerPrompt,
10521052
GenerationSampleCount = executionSettings.ResultsPerPrompt,
1053-
LogProbabilityCount = null,
1053+
LogProbabilityCount = executionSettings.TopLogprobs,
10541054
User = executionSettings.User,
10551055
DeploymentName = deploymentOrModelName
10561056
};
@@ -1102,7 +1102,9 @@ private ChatCompletionsOptions CreateChatCompletionsOptions(
11021102
ChoiceCount = executionSettings.ResultsPerPrompt,
11031103
DeploymentName = deploymentOrModelName,
11041104
Seed = executionSettings.Seed,
1105-
User = executionSettings.User
1105+
User = executionSettings.User,
1106+
LogProbabilitiesPerToken = executionSettings.TopLogprobs,
1107+
EnableLogProbabilities = executionSettings.Logprobs
11061108
};
11071109

11081110
switch (executionSettings.ResponseFormat)

dotnet/src/Connectors/Connectors.OpenAI/OpenAIPromptExecutionSettings.cs

+38-1
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,39 @@ public string? User
254254
}
255255
}
256256

257+
/// <summary>
258+
/// Whether to return log probabilities of the output tokens or not.
259+
/// If true, returns the log probabilities of each output token returned in the `content` of `message`.
260+
/// </summary>
261+
[Experimental("SKEXP0010")]
262+
[JsonPropertyName("logprobs")]
263+
public bool? Logprobs
264+
{
265+
get => this._logprobs;
266+
267+
set
268+
{
269+
this.ThrowIfFrozen();
270+
this._logprobs = value;
271+
}
272+
}
273+
274+
/// <summary>
275+
/// An integer specifying the number of most likely tokens to return at each token position, each with an associated log probability.
276+
/// </summary>
277+
[Experimental("SKEXP0010")]
278+
[JsonPropertyName("top_logprobs")]
279+
public int? TopLogprobs
280+
{
281+
get => this._topLogprobs;
282+
283+
set
284+
{
285+
this.ThrowIfFrozen();
286+
this._topLogprobs = value;
287+
}
288+
}
289+
257290
/// <inheritdoc/>
258291
public override void Freeze()
259292
{
@@ -294,7 +327,9 @@ public override PromptExecutionSettings Clone()
294327
TokenSelectionBiases = this.TokenSelectionBiases is not null ? new Dictionary<int, int>(this.TokenSelectionBiases) : null,
295328
ToolCallBehavior = this.ToolCallBehavior,
296329
User = this.User,
297-
ChatSystemPrompt = this.ChatSystemPrompt
330+
ChatSystemPrompt = this.ChatSystemPrompt,
331+
Logprobs = this.Logprobs,
332+
TopLogprobs = this.TopLogprobs
298333
};
299334
}
300335

@@ -370,6 +405,8 @@ public static OpenAIPromptExecutionSettings FromExecutionSettingsWithData(Prompt
370405
private ToolCallBehavior? _toolCallBehavior;
371406
private string? _user;
372407
private string? _chatSystemPrompt;
408+
private bool? _logprobs;
409+
private int? _topLogprobs;
373410

374411
#endregion
375412
}

dotnet/src/Connectors/Connectors.UnitTests/OpenAI/ChatCompletion/AzureOpenAIChatCompletionServiceTests.cs

+5-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,9 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync()
161161
ResultsPerPrompt = 5,
162162
Seed = 567,
163163
TokenSelectionBiases = new Dictionary<int, int> { { 2, 3 } },
164-
StopSequences = ["stop_sequence"]
164+
StopSequences = ["stop_sequence"],
165+
Logprobs = true,
166+
TopLogprobs = 5
165167
};
166168

167169
var chatHistory = new ChatHistory();
@@ -218,6 +220,8 @@ public async Task GetChatMessageContentsHandlesSettingsCorrectlyAsync()
218220
Assert.Equal(567, content.GetProperty("seed").GetInt32());
219221
Assert.Equal(3, content.GetProperty("logit_bias").GetProperty("2").GetInt32());
220222
Assert.Equal("stop_sequence", content.GetProperty("stop")[0].GetString());
223+
Assert.True(content.GetProperty("logprobs").GetBoolean());
224+
Assert.Equal(5, content.GetProperty("top_logprobs").GetInt32());
221225
}
222226

223227
[Theory]

dotnet/src/Connectors/Connectors.UnitTests/OpenAI/OpenAIPromptExecutionSettingsTests.cs

+17-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ public void ItCreatesOpenAIExecutionSettingsWithCorrectDefaults()
3030
Assert.Equal(1, executionSettings.ResultsPerPrompt);
3131
Assert.Null(executionSettings.StopSequences);
3232
Assert.Null(executionSettings.TokenSelectionBiases);
33+
Assert.Null(executionSettings.TopLogprobs);
34+
Assert.Null(executionSettings.Logprobs);
3335
Assert.Equal(128, executionSettings.MaxTokens);
3436
}
3537

@@ -47,6 +49,8 @@ public void ItUsesExistingOpenAIExecutionSettings()
4749
StopSequences = new string[] { "foo", "bar" },
4850
ChatSystemPrompt = "chat system prompt",
4951
MaxTokens = 128,
52+
Logprobs = true,
53+
TopLogprobs = 5,
5054
TokenSelectionBiases = new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } },
5155
};
5256

@@ -97,6 +101,8 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()
97101
{ "max_tokens", 128 },
98102
{ "token_selection_biases", new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } } },
99103
{ "seed", 123456 },
104+
{ "logprobs", true },
105+
{ "top_logprobs", 5 },
100106
}
101107
};
102108

@@ -105,7 +111,6 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesSnakeCase()
105111

106112
// Assert
107113
AssertExecutionSettings(executionSettings);
108-
Assert.Equal(executionSettings.Seed, 123456);
109114
}
110115

111116
[Fact]
@@ -124,7 +129,10 @@ public void ItCreatesOpenAIExecutionSettingsFromExtraPropertiesAsStrings()
124129
{ "stop_sequences", new [] { "foo", "bar" } },
125130
{ "chat_system_prompt", "chat system prompt" },
126131
{ "max_tokens", "128" },
127-
{ "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } }
132+
{ "token_selection_biases", new Dictionary<string, string>() { { "1", "2" }, { "3", "4" } } },
133+
{ "seed", 123456 },
134+
{ "logprobs", true },
135+
{ "top_logprobs", 5 }
128136
}
129137
};
130138

@@ -149,7 +157,10 @@ public void ItCreatesOpenAIExecutionSettingsFromJsonSnakeCase()
149157
"stop_sequences": [ "foo", "bar" ],
150158
"chat_system_prompt": "chat system prompt",
151159
"token_selection_biases": { "1": 2, "3": 4 },
152-
"max_tokens": 128
160+
"max_tokens": 128,
161+
"seed": 123456,
162+
"logprobs": true,
163+
"top_logprobs": 5
153164
}
154165
""";
155166
var actualSettings = JsonSerializer.Deserialize<PromptExecutionSettings>(json);
@@ -255,5 +266,8 @@ private static void AssertExecutionSettings(OpenAIPromptExecutionSettings execut
255266
Assert.Equal("chat system prompt", executionSettings.ChatSystemPrompt);
256267
Assert.Equal(new Dictionary<int, int>() { { 1, 2 }, { 3, 4 } }, executionSettings.TokenSelectionBiases);
257268
Assert.Equal(128, executionSettings.MaxTokens);
269+
Assert.Equal(123456, executionSettings.Seed);
270+
Assert.Equal(true, executionSettings.Logprobs);
271+
Assert.Equal(5, executionSettings.TopLogprobs);
258272
}
259273
}

dotnet/src/Connectors/Connectors.UnitTests/OpenAI/TextGeneration/AzureOpenAITextGenerationServiceTests.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ public async Task GetTextContentsHandlesSettingsCorrectlyAsync()
126126
PresencePenalty = 1.2,
127127
ResultsPerPrompt = 5,
128128
TokenSelectionBiases = new Dictionary<int, int> { { 2, 3 } },
129-
StopSequences = ["stop_sequence"]
129+
StopSequences = ["stop_sequence"],
130+
TopLogprobs = 5
130131
};
131132

132133
this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK)
@@ -154,6 +155,7 @@ public async Task GetTextContentsHandlesSettingsCorrectlyAsync()
154155
Assert.Equal(5, content.GetProperty("best_of").GetInt32());
155156
Assert.Equal(3, content.GetProperty("logit_bias").GetProperty("2").GetInt32());
156157
Assert.Equal("stop_sequence", content.GetProperty("stop")[0].GetString());
158+
Assert.Equal(5, content.GetProperty("logprobs").GetInt32());
157159
}
158160

159161
[Fact]

dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAICompletionTests.cs

+33
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
using System.Text.Json;
1010
using System.Threading;
1111
using System.Threading.Tasks;
12+
using Azure.AI.OpenAI;
1213
using Microsoft.Extensions.Configuration;
1314
using Microsoft.Extensions.DependencyInjection;
1415
using Microsoft.Extensions.Http.Resilience;
@@ -504,6 +505,38 @@ public async Task SemanticKernelVersionHeaderIsSentAsync()
504505
Assert.True(httpHeaderHandler.RequestHeaders.TryGetValues("Semantic-Kernel-Version", out var values));
505506
}
506507

508+
[Theory(Skip = "This test is for manual verification.")]
509+
[InlineData(null, null)]
510+
[InlineData(false, null)]
511+
[InlineData(true, 2)]
512+
[InlineData(true, 5)]
513+
public async Task LogProbsDataIsReturnedWhenRequestedAsync(bool? logprobs, int? topLogprobs)
514+
{
515+
// Arrange
516+
var settings = new OpenAIPromptExecutionSettings { Logprobs = logprobs, TopLogprobs = topLogprobs };
517+
518+
this._kernelBuilder.Services.AddSingleton<ILoggerFactory>(this._logger);
519+
var builder = this._kernelBuilder;
520+
this.ConfigureAzureOpenAIChatAsText(builder);
521+
Kernel target = builder.Build();
522+
523+
// Act
524+
var result = await target.InvokePromptAsync("Hi, can you help me today?", new(settings));
525+
526+
var logProbabilityInfo = result.Metadata?["LogProbabilityInfo"] as ChatChoiceLogProbabilityInfo;
527+
528+
// Assert
529+
if (logprobs is true)
530+
{
531+
Assert.NotNull(logProbabilityInfo);
532+
Assert.Equal(topLogprobs, logProbabilityInfo.TokenLogProbabilityResults[0].TopLogProbabilityEntries.Count);
533+
}
534+
else
535+
{
536+
Assert.Null(logProbabilityInfo);
537+
}
538+
}
539+
507540
#region internals
508541

509542
private readonly XunitLogger<Kernel> _logger = new(output);

0 commit comments

Comments (0)