Skip to content

Commit a3bd256

Browse files
authored
Improve handling of JSON Schema in OpenAI API Response Context (#819)
* feat: add jsonschema.Validate and jsonschema.Unmarshal * fix Sanity check * remove slices.Contains * fix Sanity check * add SchemaWrapper * update api_integration_test.go * update method 'reflectSchema' to support 'omitempty' in JSON tag * add GenerateSchemaForType * update json_test.go * update `Warp` to `Wrap` * fix Sanity check * fix Sanity check * update api_internal_test.go * update README.md * update README.md * remove jsonschema.SchemaWrapper * remove jsonschema.SchemaWrapper * fix Sanity check * optimize code formatting
1 parent 5162adb commit a3bd256

File tree

7 files changed

+412
-30
lines changed

7 files changed

+412
-30
lines changed

README.md

+64
Original file line numberDiff line numberDiff line change
@@ -743,6 +743,70 @@ func main() {
743743
}
744744
```
745745
</details>
746+
747+
<details>
748+
<summary>Structured Outputs</summary>
749+
750+
```go
751+
package main
752+
753+
import (
754+
"context"
755+
"fmt"
756+
"log"
757+
758+
"github.com/sashabaranov/go-openai"
759+
"github.com/sashabaranov/go-openai/jsonschema"
760+
)
761+
762+
func main() {
763+
client := openai.NewClient("your token")
764+
ctx := context.Background()
765+
766+
type Result struct {
767+
Steps []struct {
768+
Explanation string `json:"explanation"`
769+
Output string `json:"output"`
770+
} `json:"steps"`
771+
FinalAnswer string `json:"final_answer"`
772+
}
773+
var result Result
774+
schema, err := jsonschema.GenerateSchemaForType(result)
775+
if err != nil {
776+
log.Fatalf("GenerateSchemaForType error: %v", err)
777+
}
778+
resp, err := client.CreateChatCompletion(ctx, openai.ChatCompletionRequest{
779+
Model: openai.GPT4oMini,
780+
Messages: []openai.ChatCompletionMessage{
781+
{
782+
Role: openai.ChatMessageRoleSystem,
783+
Content: "You are a helpful math tutor. Guide the user through the solution step by step.",
784+
},
785+
{
786+
Role: openai.ChatMessageRoleUser,
787+
Content: "how can I solve 8x + 7 = -23",
788+
},
789+
},
790+
ResponseFormat: &openai.ChatCompletionResponseFormat{
791+
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
792+
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
793+
Name: "math_reasoning",
794+
Schema: schema,
795+
Strict: true,
796+
},
797+
},
798+
})
799+
if err != nil {
800+
log.Fatalf("CreateChatCompletion error: %v", err)
801+
}
802+
err = schema.Unmarshal(resp.Choices[0].Message.Content, &result)
803+
if err != nil {
804+
log.Fatalf("Unmarshal schema error: %v", err)
805+
}
806+
fmt.Println(result)
807+
}
808+
```
809+
</details>
746810
See the `examples/` folder for more.
747811

748812
## Frequently Asked Questions

api_integration_test.go

+16-20
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package openai_test
44

55
import (
66
"context"
7-
"encoding/json"
87
"errors"
98
"io"
109
"os"
@@ -190,6 +189,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
190189
c := openai.NewClient(apiToken)
191190
ctx := context.Background()
192191

192+
type MyStructuredResponse struct {
193+
PascalCase string `json:"pascal_case" required:"true" description:"PascalCase"`
194+
CamelCase string `json:"camel_case" required:"true" description:"CamelCase"`
195+
KebabCase string `json:"kebab_case" required:"true" description:"KebabCase"`
196+
SnakeCase string `json:"snake_case" required:"true" description:"SnakeCase"`
197+
}
198+
var result MyStructuredResponse
199+
schema, err := jsonschema.GenerateSchemaForType(result)
200+
if err != nil {
201+
t.Fatal("CreateChatCompletion (use json_schema response) GenerateSchemaForType error")
202+
}
193203
resp, err := c.CreateChatCompletion(
194204
ctx,
195205
openai.ChatCompletionRequest{
@@ -212,31 +222,17 @@ func TestChatCompletionResponseFormat_JSONSchema(t *testing.T) {
212222
ResponseFormat: &openai.ChatCompletionResponseFormat{
213223
Type: openai.ChatCompletionResponseFormatTypeJSONSchema,
214224
JSONSchema: &openai.ChatCompletionResponseFormatJSONSchema{
215-
Name: "cases",
216-
Schema: jsonschema.Definition{
217-
Type: jsonschema.Object,
218-
Properties: map[string]jsonschema.Definition{
219-
"PascalCase": jsonschema.Definition{Type: jsonschema.String},
220-
"CamelCase": jsonschema.Definition{Type: jsonschema.String},
221-
"KebabCase": jsonschema.Definition{Type: jsonschema.String},
222-
"SnakeCase": jsonschema.Definition{Type: jsonschema.String},
223-
},
224-
Required: []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"},
225-
AdditionalProperties: false,
226-
},
225+
Name: "cases",
226+
Schema: schema,
227227
Strict: true,
228228
},
229229
},
230230
},
231231
)
232232
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) returned error")
233-
var result = make(map[string]string)
234-
err = json.Unmarshal([]byte(resp.Choices[0].Message.Content), &result)
235-
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
236-
for _, key := range []string{"PascalCase", "CamelCase", "KebabCase", "SnakeCase"} {
237-
if _, ok := result[key]; !ok {
238-
t.Errorf("key:%s does not exist.", key)
239-
}
233+
if err == nil {
234+
err = schema.Unmarshal(resp.Choices[0].Message.Content, &result)
235+
checks.NoError(t, err, "CreateChatCompletion (use json_schema response) unmarshal error")
240236
}
241237
}
242238

chat.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@ import (
55
"encoding/json"
66
"errors"
77
"net/http"
8-
9-
"github.com/sashabaranov/go-openai/jsonschema"
108
)
119

1210
// Chat message role defined by the OpenAI API.
@@ -187,10 +185,10 @@ type ChatCompletionResponseFormat struct {
187185
}
188186

189187
type ChatCompletionResponseFormatJSONSchema struct {
190-
Name string `json:"name"`
191-
Description string `json:"description,omitempty"`
192-
Schema jsonschema.Definition `json:"schema"`
193-
Strict bool `json:"strict"`
188+
Name string `json:"name"`
189+
Description string `json:"description,omitempty"`
190+
Schema json.Marshaler `json:"schema"`
191+
Strict bool `json:"strict"`
194192
}
195193

196194
// ChatCompletionRequest represents a request structure for chat completion API.

example_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func ExampleClient_CreateChatCompletionStream() {
5959
}
6060
defer stream.Close()
6161

62-
fmt.Printf("Stream response: ")
62+
fmt.Print("Stream response: ")
6363
for {
6464
var response openai.ChatCompletionStreamResponse
6565
response, err = stream.Recv()

jsonschema/json.go

+102-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,13 @@
44
// and/or pass in the schema in []byte format.
55
package jsonschema
66

7-
import "encoding/json"
7+
import (
8+
"encoding/json"
9+
"fmt"
10+
"reflect"
11+
"strconv"
12+
"strings"
13+
)
814

915
type DataType string
1016

@@ -42,14 +48,107 @@ type Definition struct {
4248
AdditionalProperties any `json:"additionalProperties,omitempty"`
4349
}
4450

45-
func (d Definition) MarshalJSON() ([]byte, error) {
51+
func (d *Definition) MarshalJSON() ([]byte, error) {
4652
if d.Properties == nil {
4753
d.Properties = make(map[string]Definition)
4854
}
4955
type Alias Definition
5056
return json.Marshal(struct {
5157
Alias
5258
}{
53-
Alias: (Alias)(d),
59+
Alias: (Alias)(*d),
5460
})
5561
}
62+
63+
func (d *Definition) Unmarshal(content string, v any) error {
64+
return VerifySchemaAndUnmarshal(*d, []byte(content), v)
65+
}
66+
67+
func GenerateSchemaForType(v any) (*Definition, error) {
68+
return reflectSchema(reflect.TypeOf(v))
69+
}
70+
71+
func reflectSchema(t reflect.Type) (*Definition, error) {
72+
var d Definition
73+
switch t.Kind() {
74+
case reflect.String:
75+
d.Type = String
76+
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
77+
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
78+
d.Type = Integer
79+
case reflect.Float32, reflect.Float64:
80+
d.Type = Number
81+
case reflect.Bool:
82+
d.Type = Boolean
83+
case reflect.Slice, reflect.Array:
84+
d.Type = Array
85+
items, err := reflectSchema(t.Elem())
86+
if err != nil {
87+
return nil, err
88+
}
89+
d.Items = items
90+
case reflect.Struct:
91+
d.Type = Object
92+
d.AdditionalProperties = false
93+
object, err := reflectSchemaObject(t)
94+
if err != nil {
95+
return nil, err
96+
}
97+
d = *object
98+
case reflect.Ptr:
99+
definition, err := reflectSchema(t.Elem())
100+
if err != nil {
101+
return nil, err
102+
}
103+
d = *definition
104+
case reflect.Invalid, reflect.Uintptr, reflect.Complex64, reflect.Complex128,
105+
reflect.Chan, reflect.Func, reflect.Interface, reflect.Map,
106+
reflect.UnsafePointer:
107+
return nil, fmt.Errorf("unsupported type: %s", t.Kind().String())
108+
default:
109+
}
110+
return &d, nil
111+
}
112+
113+
func reflectSchemaObject(t reflect.Type) (*Definition, error) {
114+
var d = Definition{
115+
Type: Object,
116+
AdditionalProperties: false,
117+
}
118+
properties := make(map[string]Definition)
119+
var requiredFields []string
120+
for i := 0; i < t.NumField(); i++ {
121+
field := t.Field(i)
122+
if !field.IsExported() {
123+
continue
124+
}
125+
jsonTag := field.Tag.Get("json")
126+
var required = true
127+
if jsonTag == "" {
128+
jsonTag = field.Name
129+
} else if strings.HasSuffix(jsonTag, ",omitempty") {
130+
jsonTag = strings.TrimSuffix(jsonTag, ",omitempty")
131+
required = false
132+
}
133+
134+
item, err := reflectSchema(field.Type)
135+
if err != nil {
136+
return nil, err
137+
}
138+
description := field.Tag.Get("description")
139+
if description != "" {
140+
item.Description = description
141+
}
142+
properties[jsonTag] = *item
143+
144+
if s := field.Tag.Get("required"); s != "" {
145+
required, _ = strconv.ParseBool(s)
146+
}
147+
if required {
148+
requiredFields = append(requiredFields, jsonTag)
149+
}
150+
}
151+
d.Required = requiredFields
152+
d.Properties = properties
153+
return &d, nil
154+
}

jsonschema/validate.go

+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package jsonschema
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
)
7+
8+
func VerifySchemaAndUnmarshal(schema Definition, content []byte, v any) error {
9+
var data any
10+
err := json.Unmarshal(content, &data)
11+
if err != nil {
12+
return err
13+
}
14+
if !Validate(schema, data) {
15+
return errors.New("data validation failed against the provided schema")
16+
}
17+
return json.Unmarshal(content, &v)
18+
}
19+
20+
func Validate(schema Definition, data any) bool {
21+
switch schema.Type {
22+
case Object:
23+
return validateObject(schema, data)
24+
case Array:
25+
return validateArray(schema, data)
26+
case String:
27+
_, ok := data.(string)
28+
return ok
29+
case Number: // float64 and int
30+
_, ok := data.(float64)
31+
if !ok {
32+
_, ok = data.(int)
33+
}
34+
return ok
35+
case Boolean:
36+
_, ok := data.(bool)
37+
return ok
38+
case Integer:
39+
_, ok := data.(int)
40+
return ok
41+
case Null:
42+
return data == nil
43+
default:
44+
return false
45+
}
46+
}
47+
48+
func validateObject(schema Definition, data any) bool {
49+
dataMap, ok := data.(map[string]any)
50+
if !ok {
51+
return false
52+
}
53+
for _, field := range schema.Required {
54+
if _, exists := dataMap[field]; !exists {
55+
return false
56+
}
57+
}
58+
for key, valueSchema := range schema.Properties {
59+
value, exists := dataMap[key]
60+
if exists && !Validate(valueSchema, value) {
61+
return false
62+
} else if !exists && contains(schema.Required, key) {
63+
return false
64+
}
65+
}
66+
return true
67+
}
68+
69+
func validateArray(schema Definition, data any) bool {
70+
dataArray, ok := data.([]any)
71+
if !ok {
72+
return false
73+
}
74+
for _, item := range dataArray {
75+
if !Validate(*schema.Items, item) {
76+
return false
77+
}
78+
}
79+
return true
80+
}
81+
82+
func contains[S ~[]E, E comparable](s S, v E) bool {
83+
for i := range s {
84+
if v == s[i] {
85+
return true
86+
}
87+
}
88+
return false
89+
}

0 commit comments

Comments
 (0)