diff --git a/assistant.go b/assistant.go index 8aab5bcf..3a23fcab 100644 --- a/assistant.go +++ b/assistant.go @@ -76,6 +76,9 @@ type AssistantRequest struct { ResponseFormat any `json:"response_format,omitempty"` Temperature *float32 `json:"temperature,omitempty"` TopP *float32 `json:"top_p,omitempty"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + ExtraQuery map[string]string `json:"extra_query,omitempty"` + ExtraBody map[string]any `json:"extra_body,omitempty"` } // MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases @@ -137,7 +140,11 @@ type AssistantFilesList struct { // CreateAssistant creates a new assistant. func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) { req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request), - withBetaAssistantVersion(c.config.AssistantVersion)) + withBetaAssistantVersion(c.config.AssistantVersion), + withExtraHeaders(request.ExtraHeaders), + withExtraQuery(request.ExtraQuery), + withExtraBody(request.ExtraBody), + ) if err != nil { return } diff --git a/audio.go b/audio.go index f321f93d..36faa574 100644 --- a/audio.go +++ b/audio.go @@ -49,6 +49,9 @@ type AudioRequest struct { Language string // Only for transcription. Format AudioResponseFormat TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription. + ExtraHeaders map[string]string + ExtraQuery map[string]string + ExtraBody map[string]any } // AudioResponse represents a response structure for audio API. @@ -128,6 +131,9 @@ func (c *Client) callAudioAPI( c.fullURL(urlSuffix, withModel(request.Model)), withBody(&formBody), withContentType(builder.FormDataContentType()), + withExtraHeaders(request.ExtraHeaders), + withExtraQuery(request.ExtraQuery), + withExtraBody(request.ExtraBody), ) if err != nil { return AudioResponse{}, err diff --git a/batch.go b/batch.go index 3c1a9d0d..a155f4ed 100644 --- a/batch.go +++ b/batch.go @@ -97,10 +97,13 @@ type BatchRequestCounts struct { } type CreateBatchRequest struct { - InputFileID string `json:"input_file_id"` - Endpoint BatchEndpoint `json:"endpoint"` - CompletionWindow string `json:"completion_window"` - Metadata map[string]any `json:"metadata"` + InputFileID string `json:"input_file_id"` + Endpoint BatchEndpoint `json:"endpoint"` + CompletionWindow string `json:"completion_window"` + Metadata map[string]any `json:"metadata"` + ExtraHeaders map[string]string `json:"extra_headers"` + ExtraQuery map[string]string `json:"extra_query"` + ExtraBody map[string]any `json:"extra_body"` } type BatchResponse struct { @@ -117,7 +120,12 @@ func (c *Client) CreateBatch( request.CompletionWindow = "24h" } - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), + withBody(request), + withExtraHeaders(request.ExtraHeaders), + withExtraQuery(request.ExtraQuery), + withExtraBody(request.ExtraBody), + ) if err != nil { return } diff --git a/chat.go b/chat.go index 995860c4..98b93c4c 100644 --- a/chat.go +++ b/chat.go @@ -263,6 +263,12 @@ type ChatCompletionRequest struct { ReasoningEffort string `json:"reasoning_effort,omitempty"` // Metadata to store with the completion. Metadata map[string]string `json:"metadata,omitempty"` + // Add additional JSON properties to the request + ExtraBody map[string]any `json:"extra_body,omitempty"` + // ExtraHeaders to add to the request + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + // ExtraQueryParams to add to the request + ExtraQuery map[string]string `json:"extra_query,omitempty"` } type StreamOptions struct { @@ -403,6 +409,9 @@ func (c *Client) CreateChatCompletion( http.MethodPost, c.fullURL(urlSuffix, withModel(request.Model)), withBody(request), + withExtraHeaders(request.ExtraHeaders), + withExtraQuery(request.ExtraQuery), + withExtraBody(request.ExtraBody), ) if err != nil { return diff --git a/client.go b/client.go index cef37534..9df18a06 100644 --- a/client.go +++ b/client.go @@ -72,8 +72,11 @@ func NewOrgClient(authToken, org string) *Client { } type requestOptions struct { - body any - header http.Header + body any + header http.Header + extraHeaders map[string]string + extraQuery map[string]string + extraBody map[string]any } type requestOption func(*requestOptions) @@ -84,6 +87,30 @@ func withBody(body any) requestOption { } } +func withHeader(header http.Header) requestOption { + return func(args *requestOptions) { + args.header = header + } +} + +func withExtraHeaders(extraHeaders map[string]string) requestOption { + return func(args *requestOptions) { + args.extraHeaders = extraHeaders + } +} + +func withExtraQuery(extraQuery map[string]string) requestOption { + return func(args *requestOptions) { + args.extraQuery = extraQuery + } +} + +func withExtraBody(extraBody map[string]any) requestOption { + return func(args *requestOptions) { + args.extraBody = extraBody + } +} + func withContentType(contentType string) requestOption { return func(args *requestOptions) { args.header.Set("Content-Type", contentType) @@ -105,7 +132,15 @@ func (c *Client) newRequest(ctx context.Context, method, url string, setters ... for _, setter := range setters { setter(args) } - req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header) + req, err := c.requestBuilder.Build(ctx, &utils.Request{ + Method: method, + URL: url, + Body: args.body, + Header: args.header, + ExtraHeaders: args.extraHeaders, + ExtraQuery: args.extraQuery, + ExtraBody: args.extraBody, + }) if err != nil { return nil, err } diff --git a/client_test.go b/client_test.go index 32197144..d1ac1dae 100644 --- a/client_test.go +++ b/client_test.go @@ -10,6 +10,7 @@ import ( "reflect" "testing" + utils "github.com/sashabaranov/go-openai/internal" "github.com/sashabaranov/go-openai/internal/test" "github.com/sashabaranov/go-openai/internal/test/checks" ) @@ -18,7 +19,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed") type failingRequestBuilder struct{} -func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) { +func (*failingRequestBuilder) Build(_ context.Context, _ *utils.Request) (*http.Request, error) { return nil, errTestRequestBuilderFailed } diff --git a/embeddings.go b/embeddings.go index 4a0e682d..b92eb610 100644 --- a/embeddings.go +++ b/embeddings.go @@ -159,7 +159,10 @@ type EmbeddingRequest struct { EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"` // Dimensions The number of dimensions the resulting output embeddings should have. // Only supported in text-embedding-3 and later models. - Dimensions int `json:"dimensions,omitempty"` + Dimensions int `json:"dimensions,omitempty"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + ExtraQuery map[string]string `json:"extra_query,omitempty"` + ExtraBody map[string]any `json:"extra_body,omitempty"` } func (r EmbeddingRequest) Convert() EmbeddingRequest { @@ -246,6 +249,9 @@ func (c *Client) CreateEmbeddings( http.MethodPost, c.fullURL("/embeddings", withModel(string(baseReq.Model))), withBody(baseReq), + withExtraHeaders(baseReq.ExtraHeaders), + withExtraQuery(baseReq.ExtraQuery), + withExtraBody(baseReq.ExtraBody), ) if err != nil { return diff --git a/internal/marshaller.go b/internal/marshaller.go index 223a4dc1..0ec2e0a5 100644 --- a/internal/marshaller.go +++ b/internal/marshaller.go @@ -6,6 +6,7 @@ import ( type Marshaller interface { Marshal(value any) ([]byte, error) + Unmarshal(data []byte, value any) error } type JSONMarshaller struct{} @@ -13,3 +14,7 @@ type JSONMarshaller struct{} func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) { return json.Marshal(value) } + +func (jm *JSONMarshaller) Unmarshal(data []byte, value any) error { + return json.Unmarshal(data, value) +} diff --git a/internal/request_builder.go b/internal/request_builder.go index 5699f6b1..8e2d2075 100644 --- a/internal/request_builder.go +++ b/internal/request_builder.go @@ -3,12 +3,25 @@ package openai import ( "bytes" "context" + "fmt" "io" "net/http" + "net/url" + "strings" ) +type Request struct { + Method string + URL string + Body any + Header http.Header + ExtraBody map[string]any + ExtraHeaders map[string]string + ExtraQuery map[string]string +} + type RequestBuilder interface { - Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error) + Build(ctx context.Context, request *Request) (*http.Request, error) } type HTTPRequestBuilder struct { @@ -23,30 +36,61 @@ func NewRequestBuilder() *HTTPRequestBuilder { func (b *HTTPRequestBuilder) Build( ctx context.Context, - method string, - url string, - body any, - header http.Header, + request *Request, ) (req *http.Request, err error) { var bodyReader io.Reader - if body != nil { - if v, ok := body.(io.Reader); ok { + if request.Body != nil { + if v, ok := request.Body.(io.Reader); ok { bodyReader = v } else { var reqBytes []byte - reqBytes, err = b.marshaller.Marshal(body) + reqBytes, err = b.marshaller.Marshal(request.Body) if err != nil { return } + + if request.ExtraBody != nil { + rawMap := make(map[string]any) + err = b.marshaller.Unmarshal(reqBytes, &rawMap) + if err != nil { + return + } + + for k, v := range request.ExtraBody { + rawMap[k] = v + } + reqBytes, err = b.marshaller.Marshal(rawMap) + if err != nil { + return + } + } + bodyReader = bytes.NewBuffer(reqBytes) } } - req, err = http.NewRequestWithContext(ctx, method, url, bodyReader) + + requestUrl := request.URL + if request.ExtraQuery != nil { + for k, v := range request.ExtraQuery { + requestUrl = fmt.Sprintf("%s&%s=%s", requestUrl, k, url.QueryEscape(v)) + } + } + + requestUrl = strings.TrimSuffix(requestUrl, "&") + + req, err = http.NewRequestWithContext(ctx, request.Method, requestUrl, bodyReader) if err != nil { return } - if header != nil { - req.Header = header + if request.Header != nil { + req.Header = request.Header + } else { + req.Header = make(http.Header) } + + for k, v := range request.ExtraHeaders { + req.Header.Set(k, v) + } + return } diff --git a/internal/request_builder_test.go b/internal/request_builder_test.go index e26022a6..15905cdd 100644 --- a/internal/request_builder_test.go +++ b/internal/request_builder_test.go @@ -17,12 +17,23 @@ func (*failingMarshaller) Marshal(_ any) ([]byte, error) { return []byte{}, errTestMarshallerFailed } +func (*failingMarshaller) Unmarshal(_ []byte, _ any) error { + return errTestMarshallerFailed +} + func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) { builder := HTTPRequestBuilder{ marshaller: &failingMarshaller{}, } - _, err := builder.Build(context.Background(), "", "", struct{}{}, nil) + _, err := builder.Build(context.Background(), &Request{ + Method: http.MethodGet, + URL: "/foo", + Body: struct{}{}, + Header: nil, + ExtraHeaders: nil, + ExtraQuery: nil, + }) if !errors.Is(err, errTestMarshallerFailed) { t.Fatalf("Did not return error when marshaller failed: %v", err) } @@ -38,7 +49,14 @@ func TestRequestBuilderReturnsRequest(t *testing.T) { reqBytes, _ = b.marshaller.Marshal(request) want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes)) ) - got, _ := b.Build(ctx, method, url, request, nil) + got, _ := b.Build(ctx, &Request{ + Method: method, + URL: url, + Body: request, + Header: nil, + ExtraHeaders: nil, + ExtraQuery: nil, + }) if !reflect.DeepEqual(got.Body, want.Body) || !reflect.DeepEqual(got.URL, want.URL) || !reflect.DeepEqual(got.Method, want.Method) { @@ -54,7 +72,14 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) { want, _ = http.NewRequestWithContext(ctx, method, url, nil) ) b := NewRequestBuilder() - got, _ := b.Build(ctx, method, url, nil, nil) + got, _ := b.Build(ctx, &Request{ + Method: method, + URL: url, + Body: nil, + Header: nil, + ExtraHeaders: nil, + ExtraQuery: nil, + }) if !reflect.DeepEqual(got, want) { t.Errorf("Build() got = %v, want %v", got, want) } diff --git a/moderation.go b/moderation.go index a0e09c0e..c824ae77 100644 --- a/moderation.go +++ b/moderation.go @@ -22,9 +22,7 @@ const ( ModerationText001 = "text-moderation-001" ) -var ( - ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll -) +var ErrModerationInvalidModel = errors.New("this model is not supported with moderation, please use text-moderation-stable or text-moderation-latest instead") //nolint:lll var validModerationModel = map[string]struct{}{ ModerationOmniLatest: {}, @@ -35,8 +33,11 @@ var validModerationModel = map[string]struct{}{ // ModerationRequest represents a request structure for moderation API. type ModerationRequest struct { - Input string `json:"input,omitempty"` - Model string `json:"model,omitempty"` + Input string `json:"input,omitempty"` + Model string `json:"model,omitempty"` + ExtraHeaders map[string]string `json:"extra_headers,omitempty"` + ExtraQuery map[string]string `json:"extra_query,omitempty"` + ExtraBody map[string]any `json:"extra_body,omitempty"` } // Result represents one of possible moderation results. @@ -97,6 +98,9 @@ func (c *Client) Moderations(ctx context.Context, request ModerationRequest) (re http.MethodPost, c.fullURL("/moderations", withModel(request.Model)), withBody(&request), + withExtraHeaders(request.ExtraHeaders), + withExtraQuery(request.ExtraQuery), + withExtraBody(request.ExtraBody), ) if err != nil { return