Skip to content

Commit 798d731

Browse files
author
zhouyy
committed
feat:update
1 parent ebb0e43 commit 798d731

11 files changed

+178
-67
lines changed

assistant.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ type AssistantRequest struct {
7676
ResponseFormat any `json:"response_format,omitempty"`
7777
Temperature *float32 `json:"temperature,omitempty"`
7878
TopP *float32 `json:"top_p,omitempty"`
79+
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
80+
ExtraQuery map[string]string `json:"extra_query,omitempty"`
81+
ExtraBody map[string]any `json:"extra_body,omitempty"`
7982
}
8083

8184
// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases
@@ -137,7 +140,11 @@ type AssistantFilesList struct {
137140
// CreateAssistant creates a new assistant.
138141
func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) {
139142
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request),
140-
withBetaAssistantVersion(c.config.AssistantVersion))
143+
withBetaAssistantVersion(c.config.AssistantVersion),
144+
withExtraHeaders(request.ExtraHeaders),
145+
withExtraQuery(request.ExtraQuery),
146+
withExtraBody(request.ExtraBody),
147+
)
141148
if err != nil {
142149
return
143150
}

audio.go

+6
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@ type AudioRequest struct {
4949
Language string // Only for transcription.
5050
Format AudioResponseFormat
5151
TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription.
52+
ExtraHeaders map[string]string
53+
ExtraQuery map[string]string
54+
ExtraBody map[string]any
5255
}
5356

5457
// AudioResponse represents a response structure for audio API.
@@ -128,6 +131,9 @@ func (c *Client) callAudioAPI(
128131
c.fullURL(urlSuffix, withModel(request.Model)),
129132
withBody(&formBody),
130133
withContentType(builder.FormDataContentType()),
134+
withExtraHeaders(request.ExtraHeaders),
135+
withExtraQuery(request.ExtraQuery),
136+
withExtraBody(request.ExtraBody),
131137
)
132138
if err != nil {
133139
return AudioResponse{}, err

batch.go

+13-5
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,13 @@ type BatchRequestCounts struct {
9797
}
9898

9999
type CreateBatchRequest struct {
100-
InputFileID string `json:"input_file_id"`
101-
Endpoint BatchEndpoint `json:"endpoint"`
102-
CompletionWindow string `json:"completion_window"`
103-
Metadata map[string]any `json:"metadata"`
100+
InputFileID string `json:"input_file_id"`
101+
Endpoint BatchEndpoint `json:"endpoint"`
102+
CompletionWindow string `json:"completion_window"`
103+
Metadata map[string]any `json:"metadata"`
104+
ExtraHeaders map[string]string `json:"extra_headers"`
105+
ExtraQuery map[string]string `json:"extra_query"`
106+
ExtraBody map[string]any `json:"extra_body"`
104107
}
105108

106109
type BatchResponse struct {
@@ -117,7 +120,12 @@ func (c *Client) CreateBatch(
117120
request.CompletionWindow = "24h"
118121
}
119122

120-
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request))
123+
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix),
124+
withBody(request),
125+
withExtraHeaders(request.ExtraHeaders),
126+
withExtraQuery(request.ExtraQuery),
127+
withExtraBody(request.ExtraBody),
128+
)
121129
if err != nil {
122130
return
123131
}

chat.go

+7-37
Original file line numberDiff line numberDiff line change
@@ -265,43 +265,10 @@ type ChatCompletionRequest struct {
265265
Metadata map[string]string `json:"metadata,omitempty"`
266266
// Add additional JSON properties to the request
267267
ExtraBody map[string]any `json:"extra_body,omitempty"`
268-
}
269-
270-
func (m ChatCompletionRequest) MarshalJSON() ([]byte, error) {
271-
// Create a new anonymous struct that omits ExtraBody
272-
type Alias ChatCompletionRequest
273-
temp := struct {
274-
Alias
275-
ExtraBody map[string]any `json:"-"` // Omit ExtraBody from direct serialization
276-
}{
277-
Alias: Alias(m),
278-
ExtraBody: m.ExtraBody,
279-
}
280-
281-
// First marshal the main structure
282-
data, err := json.Marshal(temp)
283-
if err != nil {
284-
return nil, err
285-
}
286-
287-
// If there's no ExtraBody, return the marshaled data as is
288-
if len(m.ExtraBody) == 0 {
289-
return data, nil
290-
}
291-
292-
// Unmarshal into a map to modify the JSON structure
293-
var rawMap map[string]any
294-
if err := json.Unmarshal(data, &rawMap); err != nil {
295-
return nil, err
296-
}
297-
298-
// Add ExtraBody fields to the root level
299-
for k, v := range m.ExtraBody {
300-
rawMap[k] = v
301-
}
302-
303-
// Marshal the combined map back to JSON
304-
return json.Marshal(rawMap)
268+
// ExtraHeaders to add to the request
269+
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
270+
// ExtraQueryParams to add to the request
271+
ExtraQuery map[string]string `json:"extra_query,omitempty"`
305272
}
306273

307274
type StreamOptions struct {
@@ -442,6 +409,9 @@ func (c *Client) CreateChatCompletion(
442409
http.MethodPost,
443410
c.fullURL(urlSuffix, withModel(request.Model)),
444411
withBody(request),
412+
withExtraHeaders(request.ExtraHeaders),
413+
withExtraQuery(request.ExtraQuery),
414+
withExtraBody(request.ExtraBody),
445415
)
446416
if err != nil {
447417
return

client.go

+38-3
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,11 @@ func NewOrgClient(authToken, org string) *Client {
7272
}
7373

7474
type requestOptions struct {
75-
body any
76-
header http.Header
75+
body any
76+
header http.Header
77+
extraHeaders map[string]string
78+
extraQuery map[string]string
79+
extraBody map[string]any
7780
}
7881

7982
type requestOption func(*requestOptions)
@@ -84,6 +87,30 @@ func withBody(body any) requestOption {
8487
}
8588
}
8689

90+
func withHeader(header http.Header) requestOption {
91+
return func(args *requestOptions) {
92+
args.header = header
93+
}
94+
}
95+
96+
func withExtraHeaders(extraHeaders map[string]string) requestOption {
97+
return func(args *requestOptions) {
98+
args.extraHeaders = extraHeaders
99+
}
100+
}
101+
102+
func withExtraQuery(extraQuery map[string]string) requestOption {
103+
return func(args *requestOptions) {
104+
args.extraQuery = extraQuery
105+
}
106+
}
107+
108+
func withExtraBody(extraBody map[string]any) requestOption {
109+
return func(args *requestOptions) {
110+
args.extraBody = extraBody
111+
}
112+
}
113+
87114
func withContentType(contentType string) requestOption {
88115
return func(args *requestOptions) {
89116
args.header.Set("Content-Type", contentType)
@@ -105,7 +132,15 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ...
105132
for _, setter := range setters {
106133
setter(args)
107134
}
108-
req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
135+
req, err := c.requestBuilder.Build(ctx, &utils.Request{
136+
Method: method,
137+
URL: url,
138+
Body: args.body,
139+
Header: args.header,
140+
ExtraHeaders: args.extraHeaders,
141+
ExtraQuery: args.extraQuery,
142+
ExtraBody: args.extraBody,
143+
})
109144
if err != nil {
110145
return nil, err
111146
}

client_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"reflect"
1111
"testing"
1212

13+
utils "github.com/sashabaranov/go-openai/internal"
1314
"github.com/sashabaranov/go-openai/internal/test"
1415
"github.com/sashabaranov/go-openai/internal/test/checks"
1516
)
@@ -18,7 +19,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed")
1819

1920
type failingRequestBuilder struct{}
2021

21-
func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) {
22+
func (*failingRequestBuilder) Build(_ context.Context, _ *utils.Request) (*http.Request, error) {
2223
return nil, errTestRequestBuilderFailed
2324
}
2425

embeddings.go

+7-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,10 @@ type EmbeddingRequest struct {
159159
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
160160
// Dimensions The number of dimensions the resulting output embeddings should have.
161161
// Only supported in text-embedding-3 and later models.
162-
Dimensions int `json:"dimensions,omitempty"`
162+
Dimensions int `json:"dimensions,omitempty"`
163+
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
164+
ExtraQuery map[string]string `json:"extra_query,omitempty"`
165+
ExtraBody map[string]any `json:"extra_body,omitempty"`
163166
}
164167

165168
func (r EmbeddingRequest) Convert() EmbeddingRequest {
@@ -246,6 +249,9 @@ func (c *Client) CreateEmbeddings(
246249
http.MethodPost,
247250
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
248251
withBody(baseReq),
252+
withExtraHeaders(baseReq.ExtraHeaders),
253+
withExtraQuery(baseReq.ExtraQuery),
254+
withExtraBody(baseReq.ExtraBody),
249255
)
250256
if err != nil {
251257
return

internal/marshaller.go

+5
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,15 @@ import (
66

77
type Marshaller interface {
88
Marshal(value any) ([]byte, error)
9+
Unmarshal(data []byte, value any) error
910
}
1011

1112
type JSONMarshaller struct{}
1213

1314
func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
1415
return json.Marshal(value)
1516
}
17+
18+
func (jm *JSONMarshaller) Unmarshal(data []byte, value any) error {
19+
return json.Unmarshal(data, value)
20+
}

internal/request_builder.go

+55-11
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,25 @@ package openai
33
import (
44
"bytes"
55
"context"
6+
"fmt"
67
"io"
78
"net/http"
9+
"net/url"
10+
"strings"
811
)
912

13+
type Request struct {
14+
Method string
15+
URL string
16+
Body any
17+
Header http.Header
18+
ExtraBody map[string]any
19+
ExtraHeaders map[string]string
20+
ExtraQuery map[string]string
21+
}
22+
1023
type RequestBuilder interface {
11-
Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error)
24+
Build(ctx context.Context, request *Request) (*http.Request, error)
1225
}
1326

1427
type HTTPRequestBuilder struct {
@@ -23,30 +36,61 @@ func NewRequestBuilder() *HTTPRequestBuilder {
2336

2437
func (b *HTTPRequestBuilder) Build(
2538
ctx context.Context,
26-
method string,
27-
url string,
28-
body any,
29-
header http.Header,
39+
request *Request,
3040
) (req *http.Request, err error) {
3141
var bodyReader io.Reader
32-
if body != nil {
33-
if v, ok := body.(io.Reader); ok {
42+
if request.Body != nil {
43+
if v, ok := request.Body.(io.Reader); ok {
3444
bodyReader = v
3545
} else {
3646
var reqBytes []byte
37-
reqBytes, err = b.marshaller.Marshal(body)
47+
reqBytes, err = b.marshaller.Marshal(request.Body)
3848
if err != nil {
3949
return
4050
}
51+
52+
if request.ExtraBody != nil {
53+
rawMap := make(map[string]any)
54+
err = b.marshaller.Unmarshal(reqBytes, &rawMap)
55+
if err != nil {
56+
return
57+
}
58+
59+
for k, v := range request.ExtraBody {
60+
rawMap[k] = v
61+
}
62+
reqBytes, err = b.marshaller.Marshal(rawMap)
63+
if err != nil {
64+
return
65+
}
66+
}
67+
4168
bodyReader = bytes.NewBuffer(reqBytes)
4269
}
4370
}
44-
req, err = http.NewRequestWithContext(ctx, method, url, bodyReader)
71+
72+
requestUrl := request.URL
73+
if request.ExtraQuery != nil {
74+
for k, v := range request.ExtraQuery {
75+
requestUrl = fmt.Sprintf("%s&%s=%s", requestUrl, k, url.QueryEscape(v))
76+
}
77+
}
78+
79+
requestUrl = strings.TrimSuffix(requestUrl, "&")
80+
81+
req, err = http.NewRequestWithContext(ctx, request.Method, requestUrl, bodyReader)
4582
if err != nil {
4683
return
4784
}
48-
if header != nil {
49-
req.Header = header
85+
if request.Header != nil {
86+
req.Header = request.Header
87+
} else {
88+
req.Header = make(http.Header)
5089
}
90+
91+
for k, v := range request.ExtraHeaders {
92+
req.Header.Set(k, v)
93+
}
94+
5195
return
5296
}

0 commit comments

Comments
 (0)