Skip to content

Commit d86425a

Browse files
authored
Allow structured outputs via function calling (sashabaranov#828)
1 parent dd7f582 commit d86425a

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

Diff for: api_integration_test.go

+76
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,79 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
239239
}
240240
}
241241
}
242+
243+
func TestChatCompletionStructuredOutputsFunctionCalling(t *testing.T) {
244+
apiToken := os.Getenv("OPENAI_TOKEN")
245+
if apiToken == "" {
246+
t.Skip("Skipping testing against production OpenAI API. Set OPENAI_TOKEN environment variable to enable it.")
247+
}
248+
249+
var err error
250+
c := openai.NewClient(apiToken)
251+
ctx := context.Background()
252+
253+
resp, err := c.CreateChatCompletion(
254+
ctx,
255+
openai.ChatCompletionRequest{
256+
Model: openai.GPT4oMini,
257+
Messages: []openai.ChatCompletionMessage{
258+
{
259+
Role: openai.ChatMessageRoleSystem,
260+
Content: "Please enter a string, and we will convert it into the following naming conventions:" +
261+
"1. PascalCase: Each word starts with an uppercase letter, with no spaces or separators." +
262+
"2. CamelCase: The first word starts with a lowercase letter, " +
263+
"and subsequent words start with an uppercase letter, with no spaces or separators." +
264+
"3. KebabCase: All letters are lowercase, with words separated by hyphens `-`." +
265+
"4. SnakeCase: All letters are lowercase, with words separated by underscores `_`.",
266+
},
267+
{
268+
Role: openai.ChatMessageRoleUser,
269+
Content: "Hello World",
270+
},
271+
},
272+
Tools: []openai.Tool{
273+
{
274+
Type: openai.ToolTypeFunction,
275+
Function: &openai.FunctionDefinition{
276+
Name: "display_cases",
277+
Strict: true,
278+
Parameters: &jsonschema.Definition{
279+
Type: jsonschema.Object,
280+
Properties: map[string]jsonschema.Definition{
281+
"PascalCase": {
282+
Type: jsonschema.String,
283+
},
284+
"CamelCase": {
285+
Type: jsonschema.String,
286+
},
287+
"KebabCase": {
288+
Type: jsonschema.String,
289+
},
290+
"SnakeCase": {
291+
Type: jsonschema.String,
292+
},
293+
},
294+
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
295+
AdditionalProperties: false,
296+
},
297+
},
298+
},
299+
},
300+
ToolChoice: openai.ToolChoice{
301+
Type: openai.ToolTypeFunction,
302+
Function: openai.ToolFunction{
303+
Name: "display_cases",
304+
},
305+
},
306+
},
307+
)
308+
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) returned error")
309+
var result = make(map[string]string)
310+
err = json.Unmarshal([]byte(resp.Choices[0].Message.ToolCalls[0].Function.Arguments), &result)
311+
checks.NoError(t, err, "CreateChatCompletion (use structured outputs response) unmarshal error")
312+
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
313+
if _, ok := result[key]; !ok {
314+
t.Errorf("key:%s does not exist.", key)
315+
}
316+
}
317+
}

Diff for: chat.go

+1
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ type ToolFunction struct {
264264
type FunctionDefinition struct {
265265
Name string `json:"name"`
266266
Description string `json:"description,omitempty"`
267+
Strict bool `json:"strict,omitempty"`
267268
// Parameters is an object describing the function.
268269
// You can pass json.RawMessage to describe the schema,
269270
// or you can pass in a struct which serializes to the proper JSON schema.

Diff for: chat_test.go

+26
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,32 @@ func TestChatCompletionsFunctions(t *testing.T) {
277277
})
278278
checks.NoError(t, err, "CreateChatCompletion with functions error")
279279
})
280+
t.Run("StructuredOutputs", func(t *testing.T) {
281+
type testMessage struct {
282+
Count int `json:"count"`
283+
Words []string `json:"words"`
284+
}
285+
msg := testMessage{
286+
Count: 2,
287+
Words: []string{"hello", "world"},
288+
}
289+
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
290+
MaxTokens: 5,
291+
Model: openai.GPT3Dot5Turbo0613,
292+
Messages: []openai.ChatCompletionMessage{
293+
{
294+
Role: openai.ChatMessageRoleUser,
295+
Content: "Hello!",
296+
},
297+
},
298+
Functions: []openai.FunctionDefinition{{
299+
Name: "test",
300+
Strict: true,
301+
Parameters: &msg,
302+
}},
303+
})
304+
checks.NoError(t, err, "CreateChatCompletion with functions error")
305+
})
280306
}
281307

282308
func TestAzureChatCompletions(t *testing.T) {

0 commit comments

Comments
 (0)