Skip to content

Commit cf8af37

Browse files
feat(client): support unions in query and forms (#347)
1 parent 4e39609 commit cf8af37

26 files changed

+367
-144
lines changed

Diff for: batch.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ type BatchListParams struct {
344344
func (f BatchListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
345345

346346
// URLQuery serializes [BatchListParams]'s query parameters as `url.Values`.
347-
func (r BatchListParams) URLQuery() (v url.Values) {
347+
func (r BatchListParams) URLQuery() (v url.Values, err error) {
348348
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
349349
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
350350
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: betaassistant.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -2257,7 +2257,7 @@ func (f BetaAssistantListParams) IsPresent() bool { return !param.IsOmitted(f) &
22572257

22582258
// URLQuery serializes [BetaAssistantListParams]'s query parameters as
22592259
// `url.Values`.
2260-
func (r BetaAssistantListParams) URLQuery() (v url.Values) {
2260+
func (r BetaAssistantListParams) URLQuery() (v url.Values, err error) {
22612261
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
22622262
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
22632263
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: betathreadmessage.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1700,7 +1700,7 @@ func (f BetaThreadMessageListParams) IsPresent() bool { return !param.IsOmitted(
17001700

17011701
// URLQuery serializes [BetaThreadMessageListParams]'s query parameters as
17021702
// `url.Values`.
1703-
func (r BetaThreadMessageListParams) URLQuery() (v url.Values) {
1703+
func (r BetaThreadMessageListParams) URLQuery() (v url.Values, err error) {
17041704
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
17051705
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
17061706
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: betathreadrun.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ func (r BetaThreadRunNewParams) MarshalJSON() (data []byte, err error) {
698698
}
699699

700700
// URLQuery serializes [BetaThreadRunNewParams]'s query parameters as `url.Values`.
701-
func (r BetaThreadRunNewParams) URLQuery() (v url.Values) {
701+
func (r BetaThreadRunNewParams) URLQuery() (v url.Values, err error) {
702702
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
703703
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
704704
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -945,7 +945,7 @@ func (f BetaThreadRunListParams) IsPresent() bool { return !param.IsOmitted(f) &
945945

946946
// URLQuery serializes [BetaThreadRunListParams]'s query parameters as
947947
// `url.Values`.
948-
func (r BetaThreadRunListParams) URLQuery() (v url.Values) {
948+
func (r BetaThreadRunListParams) URLQuery() (v url.Values, err error) {
949949
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
950950
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
951951
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: betathreadrunstep.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1301,7 +1301,7 @@ func (f BetaThreadRunStepGetParams) IsPresent() bool { return !param.IsOmitted(f
13011301

13021302
// URLQuery serializes [BetaThreadRunStepGetParams]'s query parameters as
13031303
// `url.Values`.
1304-
func (r BetaThreadRunStepGetParams) URLQuery() (v url.Values) {
1304+
func (r BetaThreadRunStepGetParams) URLQuery() (v url.Values, err error) {
13051305
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
13061306
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
13071307
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -1344,7 +1344,7 @@ func (f BetaThreadRunStepListParams) IsPresent() bool { return !param.IsOmitted(
13441344

13451345
// URLQuery serializes [BetaThreadRunStepListParams]'s query parameters as
13461346
// `url.Values`.
1347-
func (r BetaThreadRunStepListParams) URLQuery() (v url.Values) {
1347+
func (r BetaThreadRunStepListParams) URLQuery() (v url.Values, err error) {
13481348
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
13491349
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
13501350
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: chatcompletion.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -2677,7 +2677,7 @@ func (f ChatCompletionListParams) IsPresent() bool { return !param.IsOmitted(f)
26772677

26782678
// URLQuery serializes [ChatCompletionListParams]'s query parameters as
26792679
// `url.Values`.
2680-
func (r ChatCompletionListParams) URLQuery() (v url.Values) {
2680+
func (r ChatCompletionListParams) URLQuery() (v url.Values, err error) {
26812681
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
26822682
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
26832683
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: chatcompletionmessage.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func (f ChatCompletionMessageListParams) IsPresent() bool { return !param.IsOmit
8383

8484
// URLQuery serializes [ChatCompletionMessageListParams]'s query parameters as
8585
// `url.Values`.
86-
func (r ChatCompletionMessageListParams) URLQuery() (v url.Values) {
86+
func (r ChatCompletionMessageListParams) URLQuery() (v url.Values, err error) {
8787
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
8888
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
8989
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: file.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ type FileListParams struct {
303303
func (f FileListParams) IsPresent() bool { return !param.IsOmitted(f) && !f.IsNull() }
304304

305305
// URLQuery serializes [FileListParams]'s query parameters as `url.Values`.
306-
func (r FileListParams) URLQuery() (v url.Values) {
306+
func (r FileListParams) URLQuery() (v url.Values, err error) {
307307
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
308308
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
309309
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: finetuningjob.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -1498,7 +1498,7 @@ func (f FineTuningJobListParams) IsPresent() bool { return !param.IsOmitted(f) &
14981498

14991499
// URLQuery serializes [FineTuningJobListParams]'s query parameters as
15001500
// `url.Values`.
1501-
func (r FineTuningJobListParams) URLQuery() (v url.Values) {
1501+
func (r FineTuningJobListParams) URLQuery() (v url.Values, err error) {
15021502
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
15031503
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
15041504
NestedFormat: apiquery.NestedQueryFormatBrackets,
@@ -1519,7 +1519,7 @@ func (f FineTuningJobListEventsParams) IsPresent() bool { return !param.IsOmitte
15191519

15201520
// URLQuery serializes [FineTuningJobListEventsParams]'s query parameters as
15211521
// `url.Values`.
1522-
func (r FineTuningJobListEventsParams) URLQuery() (v url.Values) {
1522+
func (r FineTuningJobListEventsParams) URLQuery() (v url.Values, err error) {
15231523
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
15241524
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
15251525
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: finetuningjobcheckpoint.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ func (f FineTuningJobCheckpointListParams) IsPresent() bool {
149149

150150
// URLQuery serializes [FineTuningJobCheckpointListParams]'s query parameters as
151151
// `url.Values`.
152-
func (r FineTuningJobCheckpointListParams) URLQuery() (v url.Values) {
152+
func (r FineTuningJobCheckpointListParams) URLQuery() (v url.Values, err error) {
153153
return apiquery.MarshalWithSettings(r, apiquery.QuerySettings{
154154
ArrayFormat: apiquery.ArrayQueryFormatBrackets,
155155
NestedFormat: apiquery.NestedQueryFormatBrackets,

Diff for: internal/apiform/encoder.go

+35-3
Original file line numberDiff line numberDiff line change
@@ -226,8 +226,14 @@ func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
226226
return e.newFieldTypeEncoder(t)
227227
}
228228

229-
if idx, ok := param.OptionalPrimitiveTypes[t]; ok {
230-
return e.newRichFieldTypeEncoder(t, idx)
229+
if t.Implements(reflect.TypeOf((*param.Optional)(nil)).Elem()) {
230+
return e.newRichFieldTypeEncoder(t)
231+
}
232+
233+
for i := 0; i < t.NumField(); i++ {
234+
if t.Field(i).Type == paramUnionType && t.Field(i).Anonymous {
235+
return e.newStructUnionTypeEncoder(t)
236+
}
231237
}
232238

233239
encoderFields := []encoderField{}
@@ -325,6 +331,32 @@ func (e *encoder) newStructTypeEncoder(t reflect.Type) encoderFunc {
325331
}
326332
}
327333

334+
var paramUnionType = reflect.TypeOf((*param.APIUnion)(nil)).Elem()
335+
336+
func (e *encoder) newStructUnionTypeEncoder(t reflect.Type) encoderFunc {
337+
var fieldEncoders []encoderFunc
338+
for i := 0; i < t.NumField(); i++ {
339+
field := t.Field(i)
340+
if field.Type == paramUnionType && field.Anonymous {
341+
fieldEncoders = append(fieldEncoders, nil)
342+
continue
343+
}
344+
fieldEncoders = append(fieldEncoders, e.typeEncoder(field.Type))
345+
}
346+
347+
return func(key string, value reflect.Value, writer *multipart.Writer) error {
348+
for i := 0; i < t.NumField(); i++ {
349+
if value.Field(i).Type() == paramUnionType {
350+
continue
351+
}
352+
if !value.Field(i).IsZero() {
353+
return fieldEncoders[i](key, value.Field(i), writer)
354+
}
355+
}
356+
return fmt.Errorf("apiform: union %s has no field set", t.String())
357+
}
358+
}
359+
328360
func (e *encoder) newFieldTypeEncoder(t reflect.Type) encoderFunc {
329361
f, _ := t.FieldByName("Value")
330362
enc := e.typeEncoder(f.Type)
@@ -435,7 +467,7 @@ func (e *encoder) encodeMapEntries(key string, v reflect.Value, writer *multipar
435467
return nil
436468
}
437469

438-
func (e *encoder) newMapEncoder(t reflect.Type) encoderFunc {
470+
func (e *encoder) newMapEncoder(_ reflect.Type) encoderFunc {
439471
return func(key string, value reflect.Value, writer *multipart.Writer) error {
440472
return e.encodeMapEntries(key, value, writer)
441473
}

Diff for: internal/apiform/form_test.go

+73
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package apiform
22

33
import (
44
"bytes"
5+
"github.com/openai/openai-go/packages/param"
56
"mime/multipart"
67
"strings"
78
"testing"
@@ -103,6 +104,23 @@ func (UnionTime) union() {}
103104
type ReaderStruct struct {
104105
}
105106

107+
type NamedEnum string
108+
109+
const NamedEnumFoo NamedEnum = "foo"
110+
111+
type StructUnionWrapper struct {
112+
Union StructUnion `form:"union"`
113+
}
114+
115+
type StructUnion struct {
116+
OfInt param.Opt[int64] `form:",omitzero,inline"`
117+
OfString param.Opt[string] `form:",omitzero,inline"`
118+
OfEnum param.Opt[NamedEnum] `form:",omitzero,inline"`
119+
OfA UnionStructA `form:",omitzero,inline"`
120+
OfB UnionStructB `form:",omitzero,inline"`
121+
param.APIUnion
122+
}
123+
106124
var tests = map[string]struct {
107125
buf string
108126
val interface{}
@@ -375,6 +393,18 @@ bar
375393
},
376394
},
377395

396+
"struct_union_integer": {
397+
`--xxx
398+
Content-Disposition: form-data; name="union"
399+
400+
12
401+
--xxx--
402+
`,
403+
StructUnionWrapper{
404+
Union: StructUnion{OfInt: param.NewOpt[int64](12)},
405+
},
406+
},
407+
378408
"union_integer": {
379409
`--xxx
380410
Content-Disposition: form-data; name="union"
@@ -387,6 +417,30 @@ Content-Disposition: form-data; name="union"
387417
},
388418
},
389419

420+
"struct_union_struct_discriminated_a": {
421+
`--xxx
422+
Content-Disposition: form-data; name="union.a"
423+
424+
foo
425+
--xxx
426+
Content-Disposition: form-data; name="union.b"
427+
428+
bar
429+
--xxx
430+
Content-Disposition: form-data; name="union.type"
431+
432+
typeA
433+
--xxx--
434+
`,
435+
StructUnionWrapper{
436+
Union: StructUnion{OfA: UnionStructA{
437+
Type: "typeA",
438+
A: "foo",
439+
B: "bar",
440+
}},
441+
},
442+
},
443+
390444
"union_struct_discriminated_a": {
391445
`--xxx
392446
Content-Disposition: form-data; name="union.a"
@@ -412,6 +466,25 @@ typeA
412466
},
413467
},
414468

469+
"struct_union_struct_discriminated_b": {
470+
`--xxx
471+
Content-Disposition: form-data; name="union.a"
472+
473+
foo
474+
--xxx
475+
Content-Disposition: form-data; name="union.type"
476+
477+
typeB
478+
--xxx--
479+
`,
480+
StructUnionWrapper{
481+
Union: StructUnion{OfB: UnionStructB{
482+
Type: "typeB",
483+
A: "foo",
484+
}},
485+
},
486+
},
487+
415488
"union_struct_discriminated_b": {
416489
`--xxx
417490
Content-Disposition: form-data; name="union.a"

Diff for: internal/apiform/richparam.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ import (
66
"reflect"
77
)
88

9-
func (e *encoder) newRichFieldTypeEncoder(t reflect.Type, underlyingValueIdx []int) encoderFunc {
10-
underlying := t.FieldByIndex(underlyingValueIdx)
11-
primitiveEncoder := e.newPrimitiveTypeEncoder(underlying.Type)
9+
func (e *encoder) newRichFieldTypeEncoder(t reflect.Type) encoderFunc {
10+
f, _ := t.FieldByName("Value")
11+
enc := e.newPrimitiveTypeEncoder(f.Type)
1212
return func(key string, value reflect.Value, writer *multipart.Writer) error {
1313
if opt, ok := value.Interface().(param.Optional); ok && opt.IsPresent() {
14-
return primitiveEncoder(key, value.FieldByIndex(underlyingValueIdx), writer)
14+
return enc(key, value.FieldByIndex(f.Index), writer)
1515
} else if ok && opt.IsNull() {
1616
return writer.WriteField(key, "null")
1717
}

0 commit comments

Comments
 (0)