Skip to content

Commit dd7f582

Browse files
authored
fix: fullURL endpoint generation (sashabaranov#817)
1 parent 2c6889e commit dd7f582

14 files changed

+244
-51
lines changed

api_internal_test.go

+19-5
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,15 @@ func TestAzureFullURL(t *testing.T) {
112112
Name string
113113
BaseURL string
114114
AzureModelMapper map[string]string
115+
Suffix string
115116
Model string
116117
Expect string
117118
}{
118119
{
119120
"AzureBaseURLWithSlashAutoStrip",
120121
"https://httpbin.org/",
121122
nil,
123+
"/chat/completions",
122124
"chatgpt-demo",
123125
"https://httpbin.org/" +
124126
"openai/deployments/chatgpt-demo" +
@@ -128,19 +130,28 @@ func TestAzureFullURL(t *testing.T) {
128130
"AzureBaseURLWithoutSlashOK",
129131
"https://httpbin.org",
130132
nil,
133+
"/chat/completions",
131134
"chatgpt-demo",
132135
"https://httpbin.org/" +
133136
"openai/deployments/chatgpt-demo" +
134137
"/chat/completions?api-version=2023-05-15",
135138
},
139+
{
140+
"",
141+
"https://httpbin.org",
142+
nil,
143+
"/assistants?limit=10",
144+
"chatgpt-demo",
145+
"https://httpbin.org/openai/assistants?api-version=2023-05-15&limit=10",
146+
},
136147
}
137148

138149
for _, c := range cases {
139150
t.Run(c.Name, func(t *testing.T) {
140151
az := DefaultAzureConfig("dummy", c.BaseURL)
141152
cli := NewClientWithConfig(az)
142153
// /openai/deployments/{engine}/chat/completions?api-version={api_version}
143-
actual := cli.fullURL("/chat/completions", c.Model)
154+
actual := cli.fullURL(c.Suffix, withModel(c.Model))
144155
if actual != c.Expect {
145156
t.Errorf("Expected %s, got %s", c.Expect, actual)
146157
}
@@ -153,19 +164,22 @@ func TestCloudflareAzureFullURL(t *testing.T) {
153164
cases := []struct {
154165
Name string
155166
BaseURL string
167+
Suffix string
156168
Expect string
157169
}{
158170
{
159171
"CloudflareAzureBaseURLWithSlashAutoStrip",
160172
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/",
173+
"/chat/completions",
161174
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
162175
"chat/completions?api-version=2023-05-15",
163176
},
164177
{
165-
"CloudflareAzureBaseURLWithoutSlashOK",
178+
"",
166179
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo",
167-
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo/" +
168-
"chat/completions?api-version=2023-05-15",
180+
"/assistants?limit=10",
181+
"https://gateway.ai.cloudflare.com/v1/dnekeim2i39dmm4mldemakiem3i4mkw3/demo/azure-openai/resource/chatgpt-demo" +
182+
"/assistants?api-version=2023-05-15&limit=10",
169183
},
170184
}
171185

@@ -176,7 +190,7 @@ func TestCloudflareAzureFullURL(t *testing.T) {
176190

177191
cli := NewClientWithConfig(az)
178192

179-
actual := cli.fullURL("/chat/completions")
193+
actual := cli.fullURL(c.Suffix)
180194
if actual != c.Expect {
181195
t.Errorf("Expected %s, got %s", c.Expect, actual)
182196
}

audio.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,13 @@ func (c *Client) callAudioAPI(
122122
}
123123

124124
urlSuffix := fmt.Sprintf("/audio/%s", endpointSuffix)
125-
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model),
126-
withBody(&formBody), withContentType(builder.FormDataContentType()))
125+
req, err := c.newRequest(
126+
ctx,
127+
http.MethodPost,
128+
c.fullURL(urlSuffix, withModel(request.Model)),
129+
withBody(&formBody),
130+
withContentType(builder.FormDataContentType()),
131+
)
127132
if err != nil {
128133
return AudioResponse{}, err
129134
}

chat.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,12 @@ func (c *Client) CreateChatCompletion(
358358
return
359359
}
360360

361-
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
361+
req, err := c.newRequest(
362+
ctx,
363+
http.MethodPost,
364+
c.fullURL(urlSuffix, withModel(request.Model)),
365+
withBody(request),
366+
)
362367
if err != nil {
363368
return
364369
}

chat_stream.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,12 @@ func (c *Client) CreateChatCompletionStream(
6060
}
6161

6262
request.Stream = true
63-
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
63+
req, err := c.newRequest(
64+
ctx,
65+
http.MethodPost,
66+
c.fullURL(urlSuffix, withModel(request.Model)),
67+
withBody(request),
68+
)
6469
if err != nil {
6570
return nil, err
6671
}

client.go

+54-30
Original file line numberDiff line numberDiff line change
@@ -222,42 +222,66 @@ func decodeString(body io.Reader, output *string) error {
222222
return nil
223223
}
224224

225+
type fullURLOptions struct {
226+
model string
227+
}
228+
229+
type fullURLOption func(*fullURLOptions)
230+
231+
func withModel(model string) fullURLOption {
232+
return func(args *fullURLOptions) {
233+
args.model = model
234+
}
235+
}
236+
237+
var azureDeploymentsEndpoints = []string{
238+
"/completions",
239+
"/embeddings",
240+
"/chat/completions",
241+
"/audio/transcriptions",
242+
"/audio/translations",
243+
"/audio/speech",
244+
"/images/generations",
245+
}
246+
225247
// fullURL returns full URL for request.
226-
// args[0] is model name, if API type is Azure, model name is required to get deployment name.
227-
func (c *Client) fullURL(suffix string, args ...any) string {
228-
// /openai/deployments/{model}/chat/completions?api-version={api_version}
248+
func (c *Client) fullURL(suffix string, setters ...fullURLOption) string {
249+
baseURL := strings.TrimRight(c.config.BaseURL, "/")
250+
args := fullURLOptions{}
251+
for _, setter := range setters {
252+
setter(&args)
253+
}
254+
229255
if c.config.APIType == APITypeAzure || c.config.APIType == APITypeAzureAD {
230-
baseURL := c.config.BaseURL
231-
baseURL = strings.TrimRight(baseURL, "/")
232-
parseURL, _ := url.Parse(baseURL)
233-
query := parseURL.Query()
234-
query.Add("api-version", c.config.APIVersion)
235-
// if suffix is /models change to {endpoint}/openai/models?api-version=2022-12-01
236-
// https://learn.microsoft.com/en-us/rest/api/cognitiveservices/azureopenaistable/models/list?tabs=HTTP
237-
if containsSubstr([]string{"/models", "/assistants", "/threads", "/files"}, suffix) {
238-
return fmt.Sprintf("%s/%s%s?%s", baseURL, azureAPIPrefix, suffix, query.Encode())
239-
}
240-
azureDeploymentName := "UNKNOWN"
241-
if len(args) > 0 {
242-
model, ok := args[0].(string)
243-
if ok {
244-
azureDeploymentName = c.config.GetAzureDeploymentByModel(model)
245-
}
246-
}
247-
return fmt.Sprintf("%s/%s/%s/%s%s?%s",
248-
baseURL, azureAPIPrefix, azureDeploymentsPrefix,
249-
azureDeploymentName, suffix, query.Encode(),
250-
)
256+
baseURL = c.baseURLWithAzureDeployment(baseURL, suffix, args.model)
257+
}
258+
259+
if c.config.APIVersion != "" {
260+
suffix = c.suffixWithAPIVersion(suffix)
251261
}
262+
return fmt.Sprintf("%s%s", baseURL, suffix)
263+
}
252264

253-
// https://developers.cloudflare.com/ai-gateway/providers/azureopenai/
254-
if c.config.APIType == APITypeCloudflareAzure {
255-
baseURL := c.config.BaseURL
256-
baseURL = strings.TrimRight(baseURL, "/")
257-
return fmt.Sprintf("%s%s?api-version=%s", baseURL, suffix, c.config.APIVersion)
265+
func (c *Client) suffixWithAPIVersion(suffix string) string {
266+
parsedSuffix, err := url.Parse(suffix)
267+
if err != nil {
268+
panic("failed to parse url suffix")
258269
}
270+
query := parsedSuffix.Query()
271+
query.Add("api-version", c.config.APIVersion)
272+
return fmt.Sprintf("%s?%s", parsedSuffix.Path, query.Encode())
273+
}
259274

260-
return fmt.Sprintf("%s%s", c.config.BaseURL, suffix)
275+
func (c *Client) baseURLWithAzureDeployment(baseURL, suffix, model string) (newBaseURL string) {
276+
baseURL = fmt.Sprintf("%s/%s", strings.TrimRight(baseURL, "/"), azureAPIPrefix)
277+
if containsSubstr(azureDeploymentsEndpoints, suffix) {
278+
azureDeploymentName := c.config.GetAzureDeploymentByModel(model)
279+
if azureDeploymentName == "" {
280+
azureDeploymentName = "UNKNOWN"
281+
}
282+
baseURL = fmt.Sprintf("%s/%s/%s", baseURL, azureDeploymentsPrefix, azureDeploymentName)
283+
}
284+
return baseURL
261285
}
262286

263287
func (c *Client) handleErrorResp(resp *http.Response) error {

client_test.go

+96
Original file line numberDiff line numberDiff line change
@@ -431,3 +431,99 @@ func TestClientReturnsRequestBuilderErrorsAddition(t *testing.T) {
431431
t.Fatalf("Did not return error when request builder failed: %v", err)
432432
}
433433
}
434+
435+
func TestClient_suffixWithAPIVersion(t *testing.T) {
436+
type fields struct {
437+
apiVersion string
438+
}
439+
type args struct {
440+
suffix string
441+
}
442+
tests := []struct {
443+
name string
444+
fields fields
445+
args args
446+
want string
447+
wantPanic string
448+
}{
449+
{
450+
"",
451+
fields{apiVersion: "2023-05"},
452+
args{suffix: "/assistants"},
453+
"/assistants?api-version=2023-05",
454+
"",
455+
},
456+
{
457+
"",
458+
fields{apiVersion: "2023-05"},
459+
args{suffix: "/assistants?limit=5"},
460+
"/assistants?api-version=2023-05&limit=5",
461+
"",
462+
},
463+
{
464+
"",
465+
fields{apiVersion: "2023-05"},
466+
args{suffix: "123:assistants?limit=5"},
467+
"/assistants?api-version=2023-05&limit=5",
468+
"failed to parse url suffix",
469+
},
470+
}
471+
for _, tt := range tests {
472+
t.Run(tt.name, func(t *testing.T) {
473+
c := &Client{
474+
config: ClientConfig{APIVersion: tt.fields.apiVersion},
475+
}
476+
defer func() {
477+
if r := recover(); r != nil {
478+
if r.(string) != tt.wantPanic {
479+
t.Errorf("suffixWithAPIVersion() = %v, want %v", r, tt.wantPanic)
480+
}
481+
}
482+
}()
483+
if got := c.suffixWithAPIVersion(tt.args.suffix); got != tt.want {
484+
t.Errorf("suffixWithAPIVersion() = %v, want %v", got, tt.want)
485+
}
486+
})
487+
}
488+
}
489+
490+
func TestClient_baseURLWithAzureDeployment(t *testing.T) {
491+
type args struct {
492+
baseURL string
493+
suffix string
494+
model string
495+
}
496+
tests := []struct {
497+
name string
498+
args args
499+
wantNewBaseURL string
500+
}{
501+
{
502+
"",
503+
args{baseURL: "https://test.openai.azure.com/", suffix: assistantsSuffix, model: GPT4oMini},
504+
"https://test.openai.azure.com/openai",
505+
},
506+
{
507+
"",
508+
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: GPT4oMini},
509+
"https://test.openai.azure.com/openai/deployments/gpt-4o-mini",
510+
},
511+
{
512+
"",
513+
args{baseURL: "https://test.openai.azure.com/", suffix: chatCompletionsSuffix, model: ""},
514+
"https://test.openai.azure.com/openai/deployments/UNKNOWN",
515+
},
516+
}
517+
client := NewClient("")
518+
for _, tt := range tests {
519+
t.Run(tt.name, func(t *testing.T) {
520+
if gotNewBaseURL := client.baseURLWithAzureDeployment(
521+
tt.args.baseURL,
522+
tt.args.suffix,
523+
tt.args.model,
524+
); gotNewBaseURL != tt.wantNewBaseURL {
525+
t.Errorf("baseURLWithAzureDeployment() = %v, want %v", gotNewBaseURL, tt.wantNewBaseURL)
526+
}
527+
})
528+
}
529+
}

completion.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,12 @@ func (c *Client) CreateCompletion(
213213
return
214214
}
215215

216-
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix, request.Model), withBody(request))
216+
req, err := c.newRequest(
217+
ctx,
218+
http.MethodPost,
219+
c.fullURL(urlSuffix, withModel(request.Model)),
220+
withBody(request),
221+
)
217222
if err != nil {
218223
return
219224
}

edits.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ will need to migrate to GPT-3.5 Turbo by January 4, 2024.
3838
You can use CreateChatCompletion or CreateChatCompletionStream instead.
3939
*/
4040
func (c *Client) Edits(ctx context.Context, request EditsRequest) (response EditsResponse, err error) {
41-
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/edits", fmt.Sprint(request.Model)), withBody(request))
41+
req, err := c.newRequest(
42+
ctx,
43+
http.MethodPost,
44+
c.fullURL("/edits", withModel(fmt.Sprint(request.Model))),
45+
withBody(request),
46+
)
4247
if err != nil {
4348
return
4449
}

embeddings.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,12 @@ func (c *Client) CreateEmbeddings(
241241
conv EmbeddingRequestConverter,
242242
) (res EmbeddingResponse, err error) {
243243
baseReq := conv.Convert()
244-
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", string(baseReq.Model)), withBody(baseReq))
244+
req, err := c.newRequest(
245+
ctx,
246+
http.MethodPost,
247+
c.fullURL("/embeddings", withModel(string(baseReq.Model))),
248+
withBody(baseReq),
249+
)
245250
if err != nil {
246251
return
247252
}

example_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ func ExampleClient_CreateChatCompletionStream() {
7373
return
7474
}
7575

76-
fmt.Printf(response.Choices[0].Delta.Content)
76+
fmt.Println(response.Choices[0].Delta.Content)
7777
}
7878
}
7979

0 commit comments

Comments
 (0)