Skip to content

feat: add ExtraBody field to ChatCompletionRequest for additional JSO… #961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion assistant.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ type AssistantRequest struct {
ResponseFormat any `json:"response_format,omitempty"`
Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 `json:"top_p,omitempty"`
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
ExtraQuery map[string]string `json:"extra_query,omitempty"`
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

// MarshalJSON provides a custom marshaller for the assistant request to handle the API use cases
Expand Down Expand Up @@ -137,7 +140,11 @@ type AssistantFilesList struct {
// CreateAssistant creates a new assistant.
func (c *Client) CreateAssistant(ctx context.Context, request AssistantRequest) (response Assistant, err error) {
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(assistantsSuffix), withBody(request),
withBetaAssistantVersion(c.config.AssistantVersion))
withBetaAssistantVersion(c.config.AssistantVersion),
withExtraHeaders(request.ExtraHeaders),
withExtraQuery(request.ExtraQuery),
withExtraBody(request.ExtraBody),
)
if err != nil {
return
}
Expand Down
6 changes: 6 additions & 0 deletions audio.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ type AudioRequest struct {
Language string // Only for transcription.
Format AudioResponseFormat
TimestampGranularities []TranscriptionTimestampGranularity // Only for transcription.
ExtraHeaders map[string]string
ExtraQuery map[string]string
ExtraBody map[string]any
}

// AudioResponse represents a response structure for audio API.
Expand Down Expand Up @@ -128,6 +131,9 @@ func (c *Client) callAudioAPI(
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(&formBody),
withContentType(builder.FormDataContentType()),
withExtraHeaders(request.ExtraHeaders),
withExtraQuery(request.ExtraQuery),
withExtraBody(request.ExtraBody),
)
if err != nil {
return AudioResponse{}, err
Expand Down
18 changes: 13 additions & 5 deletions batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ type BatchRequestCounts struct {
}

type CreateBatchRequest struct {
InputFileID string `json:"input_file_id"`
Endpoint BatchEndpoint `json:"endpoint"`
CompletionWindow string `json:"completion_window"`
Metadata map[string]any `json:"metadata"`
InputFileID string `json:"input_file_id"`
Endpoint BatchEndpoint `json:"endpoint"`
CompletionWindow string `json:"completion_window"`
Metadata map[string]any `json:"metadata"`
ExtraHeaders map[string]string `json:"extra_headers"`
ExtraQuery map[string]string `json:"extra_query"`
ExtraBody map[string]any `json:"extra_body"`
}

type BatchResponse struct {
Expand All @@ -117,7 +120,12 @@ func (c *Client) CreateBatch(
request.CompletionWindow = "24h"
}

req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix), withBody(request))
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(batchesSuffix),
withBody(request),
withExtraHeaders(request.ExtraHeaders),
withExtraQuery(request.ExtraQuery),
withExtraBody(request.ExtraBody),
)
if err != nil {
return
}
Expand Down
9 changes: 9 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ type ChatCompletionRequest struct {
ReasoningEffort string `json:"reasoning_effort,omitempty"`
// Metadata to store with the completion.
Metadata map[string]string `json:"metadata,omitempty"`
// Add additional JSON properties to the request
ExtraBody map[string]any `json:"extra_body,omitempty"`
// ExtraHeaders to add to the request
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
// ExtraQueryParams to add to the request
ExtraQuery map[string]string `json:"extra_query,omitempty"`
}

type StreamOptions struct {
Expand Down Expand Up @@ -403,6 +409,9 @@ func (c *Client) CreateChatCompletion(
http.MethodPost,
c.fullURL(urlSuffix, withModel(request.Model)),
withBody(request),
withExtraHeaders(request.ExtraHeaders),
withExtraQuery(request.ExtraQuery),
withExtraBody(request.ExtraBody),
)
if err != nil {
return
Expand Down
41 changes: 38 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,11 @@
}

type requestOptions struct {
body any
header http.Header
body any
header http.Header
extraHeaders map[string]string
extraQuery map[string]string
extraBody map[string]any
}

type requestOption func(*requestOptions)
Expand All @@ -84,6 +87,30 @@
}
}

func withHeader(header http.Header) requestOption {

Check failure on line 90 in client.go

View workflow job for this annotation

GitHub Actions / Sanity check

func `withHeader` is unused (unused)
return func(args *requestOptions) {
args.header = header
}
}

func withExtraHeaders(extraHeaders map[string]string) requestOption {
return func(args *requestOptions) {
args.extraHeaders = extraHeaders
}
}

func withExtraQuery(extraQuery map[string]string) requestOption {
return func(args *requestOptions) {
args.extraQuery = extraQuery
}
}

func withExtraBody(extraBody map[string]any) requestOption {
return func(args *requestOptions) {
args.extraBody = extraBody
}
}

func withContentType(contentType string) requestOption {
return func(args *requestOptions) {
args.header.Set("Content-Type", contentType)
Expand All @@ -105,7 +132,15 @@
for _, setter := range setters {
setter(args)
}
req, err := c.requestBuilder.Build(ctx, method, url, args.body, args.header)
req, err := c.requestBuilder.Build(ctx, &utils.Request{
Method: method,
URL: url,
Body: args.body,
Header: args.header,
ExtraHeaders: args.extraHeaders,
ExtraQuery: args.extraQuery,
ExtraBody: args.extraBody,
})
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"reflect"
"testing"

utils "github.com/sashabaranov/go-openai/internal"
"github.com/sashabaranov/go-openai/internal/test"
"github.com/sashabaranov/go-openai/internal/test/checks"
)
Expand All @@ -18,7 +19,7 @@ var errTestRequestBuilderFailed = errors.New("test request builder failed")

type failingRequestBuilder struct{}

func (*failingRequestBuilder) Build(_ context.Context, _, _ string, _ any, _ http.Header) (*http.Request, error) {
func (*failingRequestBuilder) Build(_ context.Context, _ *utils.Request) (*http.Request, error) {
return nil, errTestRequestBuilderFailed
}

Expand Down
8 changes: 7 additions & 1 deletion embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ type EmbeddingRequest struct {
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format,omitempty"`
// Dimensions The number of dimensions the resulting output embeddings should have.
// Only supported in text-embedding-3 and later models.
Dimensions int `json:"dimensions,omitempty"`
Dimensions int `json:"dimensions,omitempty"`
ExtraHeaders map[string]string `json:"extra_headers,omitempty"`
ExtraQuery map[string]string `json:"extra_query,omitempty"`
ExtraBody map[string]any `json:"extra_body,omitempty"`
}

func (r EmbeddingRequest) Convert() EmbeddingRequest {
Expand Down Expand Up @@ -246,6 +249,9 @@ func (c *Client) CreateEmbeddings(
http.MethodPost,
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
withBody(baseReq),
withExtraHeaders(baseReq.ExtraHeaders),
withExtraQuery(baseReq.ExtraQuery),
withExtraBody(baseReq.ExtraBody),
)
if err != nil {
return
Expand Down
5 changes: 5 additions & 0 deletions internal/marshaller.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@ import (

type Marshaller interface {
Marshal(value any) ([]byte, error)
Unmarshal(data []byte, value any) error
}

type JSONMarshaller struct{}

func (jm *JSONMarshaller) Marshal(value any) ([]byte, error) {
return json.Marshal(value)
}

func (jm *JSONMarshaller) Unmarshal(data []byte, value any) error {
return json.Unmarshal(data, value)
}
66 changes: 55 additions & 11 deletions internal/request_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,25 @@
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)

type Request struct {
Method string
URL string
Body any
Header http.Header
ExtraBody map[string]any
ExtraHeaders map[string]string
ExtraQuery map[string]string
}

type RequestBuilder interface {
Build(ctx context.Context, method, url string, body any, header http.Header) (*http.Request, error)
Build(ctx context.Context, request *Request) (*http.Request, error)
}

type HTTPRequestBuilder struct {
Expand All @@ -21,32 +34,63 @@
}
}

func (b *HTTPRequestBuilder) Build(

Check failure on line 37 in internal/request_builder.go

View workflow job for this annotation

GitHub Actions / Sanity check

cognitive complexity 24 of func `(*HTTPRequestBuilder).Build` is high (> 20) (gocognit)
ctx context.Context,
method string,
url string,
body any,
header http.Header,
request *Request,
) (req *http.Request, err error) {
var bodyReader io.Reader
if body != nil {
if v, ok := body.(io.Reader); ok {
if request.Body != nil {

Check failure on line 42 in internal/request_builder.go

View workflow job for this annotation

GitHub Actions / Sanity check

`if request.Body != nil` has complex nested blocks (complexity: 12) (nestif)
if v, ok := request.Body.(io.Reader); ok {
bodyReader = v
} else {
var reqBytes []byte
reqBytes, err = b.marshaller.Marshal(body)
reqBytes, err = b.marshaller.Marshal(request.Body)
if err != nil {
return
}

if request.ExtraBody != nil {
rawMap := make(map[string]any)
err = b.marshaller.Unmarshal(reqBytes, &rawMap)
if err != nil {
return
}

for k, v := range request.ExtraBody {
rawMap[k] = v
}
reqBytes, err = b.marshaller.Marshal(rawMap)
if err != nil {
return
}
}

bodyReader = bytes.NewBuffer(reqBytes)
}
}
req, err = http.NewRequestWithContext(ctx, method, url, bodyReader)

requestUrl := request.URL

Check failure on line 72 in internal/request_builder.go

View workflow job for this annotation

GitHub Actions / Sanity check

var-naming: var requestUrl should be requestURL (revive)
if request.ExtraQuery != nil {
for k, v := range request.ExtraQuery {
requestUrl = fmt.Sprintf("%s&%s=%s", requestUrl, k, url.QueryEscape(v))
}
}

requestUrl = strings.TrimSuffix(requestUrl, "&")

req, err = http.NewRequestWithContext(ctx, request.Method, requestUrl, bodyReader)
if err != nil {
return
}
if header != nil {
req.Header = header
if request.Header != nil {
req.Header = request.Header
} else {
req.Header = make(http.Header)
}

for k, v := range request.ExtraHeaders {
req.Header.Set(k, v)
}

return
}
31 changes: 28 additions & 3 deletions internal/request_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,23 @@ func (*failingMarshaller) Marshal(_ any) ([]byte, error) {
return []byte{}, errTestMarshallerFailed
}

func (*failingMarshaller) Unmarshal(_ []byte, _ any) error {
return errTestMarshallerFailed
}

func TestRequestBuilderReturnsMarshallerErrors(t *testing.T) {
builder := HTTPRequestBuilder{
marshaller: &failingMarshaller{},
}

_, err := builder.Build(context.Background(), "", "", struct{}{}, nil)
_, err := builder.Build(context.Background(), &Request{
Method: http.MethodGet,
URL: "/foo",
Body: struct{}{},
Header: nil,
ExtraHeaders: nil,
ExtraQuery: nil,
})
if !errors.Is(err, errTestMarshallerFailed) {
t.Fatalf("Did not return error when marshaller failed: %v", err)
}
Expand All @@ -38,7 +49,14 @@ func TestRequestBuilderReturnsRequest(t *testing.T) {
reqBytes, _ = b.marshaller.Marshal(request)
want, _ = http.NewRequestWithContext(ctx, method, url, bytes.NewBuffer(reqBytes))
)
got, _ := b.Build(ctx, method, url, request, nil)
got, _ := b.Build(ctx, &Request{
Method: method,
URL: url,
Body: request,
Header: nil,
ExtraHeaders: nil,
ExtraQuery: nil,
})
if !reflect.DeepEqual(got.Body, want.Body) ||
!reflect.DeepEqual(got.URL, want.URL) ||
!reflect.DeepEqual(got.Method, want.Method) {
Expand All @@ -54,7 +72,14 @@ func TestRequestBuilderReturnsRequestWhenRequestOfArgsIsNil(t *testing.T) {
want, _ = http.NewRequestWithContext(ctx, method, url, nil)
)
b := NewRequestBuilder()
got, _ := b.Build(ctx, method, url, nil, nil)
got, _ := b.Build(ctx, &Request{
Method: method,
URL: url,
Body: nil,
Header: nil,
ExtraHeaders: nil,
ExtraQuery: nil,
})
if !reflect.DeepEqual(got, want) {
t.Errorf("Build() got = %v, want %v", got, want)
}
Expand Down
Loading
Loading