Skip to content

Commit 533f0ea

Browse files
author
Yoshio Terada
committed
First commit to fix Issue openai#211
First commit to fix Issue openai#211 This commit includes the fix described in Issue openai#211. * Addressed the issue where Base64 encoding could not be handled. * Improved performance by using Base64 encoding by default.
1 parent d742459 commit 533f0ea

File tree

6 files changed

+189
-23
lines changed

6 files changed

+189
-23
lines changed

openai-java-core/src/main/kotlin/com/openai/models/embeddings/Embedding.kt

+38-17
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import com.openai.core.immutableEmptyMap
1717
import com.openai.core.toImmutable
1818
import com.openai.errors.OpenAIInvalidDataException
1919
import java.util.Objects
20+
import java.util.Optional
2021

2122
/** Represents an embedding vector returned by embedding endpoint. */
2223
@NoAutoDetect
@@ -25,7 +26,7 @@ class Embedding
2526
private constructor(
2627
@JsonProperty("embedding")
2728
@ExcludeMissing
28-
private val embedding: JsonField<List<Double>> = JsonMissing.of(),
29+
private val embedding: JsonField<EmbeddingValue> = JsonMissing.of(),
2930
@JsonProperty("index") @ExcludeMissing private val index: JsonField<Long> = JsonMissing.of(),
3031
@JsonProperty("object") @ExcludeMissing private val object_: JsonValue = JsonMissing.of(),
3132
@JsonAnySetter private val additionalProperties: Map<String, JsonValue> = immutableEmptyMap(),
@@ -35,7 +36,7 @@ private constructor(
3536
* The embedding vector, which is a list of floats. The length of vector depends on the model as
3637
* listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
3738
*/
38-
fun embedding(): List<Double> = embedding.getRequired("embedding")
39+
fun embedding(): EmbeddingValue = embedding.getRequired("embedding")
3940

4041
/** The index of the embedding in the list of embeddings. */
4142
fun index(): Long = index.getRequired("index")
@@ -47,7 +48,9 @@ private constructor(
4748
* The embedding vector, which is a list of floats. The length of vector depends on the model as
4849
* listed in the [embedding guide](https://platform.openai.com/docs/guides/embeddings).
4950
*/
50-
@JsonProperty("embedding") @ExcludeMissing fun _embedding(): JsonField<List<Double>> = embedding
51+
@JsonProperty("embedding")
52+
@ExcludeMissing
53+
fun _embedding(): JsonField<EmbeddingValue> = embedding
5154

5255
/** The index of the embedding in the list of embeddings. */
5356
@JsonProperty("index") @ExcludeMissing fun _index(): JsonField<Long> = index
@@ -92,14 +95,21 @@ private constructor(
9295
/** A builder for [Embedding]. */
9396
class Builder internal constructor() {
9497

95-
private var embedding: JsonField<MutableList<Double>>? = null
98+
private var embedding: JsonField<EmbeddingValue>? = null
9699
private var index: JsonField<Long>? = null
97100
private var object_: JsonValue = JsonValue.from("embedding")
98101
private var additionalProperties: MutableMap<String, JsonValue> = mutableMapOf()
99102

100103
@JvmSynthetic
101104
internal fun from(embedding: Embedding) = apply {
102-
this.embedding = embedding.embedding.map { it.toMutableList() }
105+
this.embedding =
106+
embedding.embedding.map {
107+
EmbeddingValue(
108+
floatEmbedding =
109+
Optional.of(it.floatEmbedding.orElse(mutableListOf()).toMutableList()),
110+
base64Embedding = it.base64Embedding,
111+
)
112+
}
103113
index = embedding.index
104114
object_ = embedding.object_
105115
additionalProperties = embedding.additionalProperties.toMutableMap()
@@ -110,27 +120,32 @@ private constructor(
110120
* model as listed in the
111121
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
112122
*/
113-
fun embedding(embedding: List<Double>) = embedding(JsonField.of(embedding))
123+
fun embedding(embedding: EmbeddingValue) = embedding(JsonField.of(embedding))
114124

115125
/**
116-
* The embedding vector, which is a list of floats. The length of vector depends on the
117-
* model as listed in the
126+
* The embedding vector, which is a list of floats or Base64. The float length of vector
127+
* depends on the model as listed in the
118128
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
119129
*/
120-
fun embedding(embedding: JsonField<List<Double>>) = apply {
121-
this.embedding = embedding.map { it.toMutableList() }
130+
fun embedding(embedding: JsonField<EmbeddingValue>) = apply {
131+
this.embedding =
132+
embedding.map {
133+
EmbeddingValue(
134+
floatEmbedding =
135+
Optional.of(it.floatEmbedding.orElse(mutableListOf()).toMutableList()),
136+
base64Embedding = it.base64Embedding,
137+
)
138+
}
122139
}
123140

124141
/**
125-
* The embedding vector, which is a list of floats. The length of vector depends on the
126-
* model as listed in the
142+
* The embedding vector, which is a list of floats or Base64. The float length of vector
143+
* depends on the model as listed in the
127144
* [embedding guide](https://platform.openai.com/docs/guides/embeddings).
128145
*/
129-
fun addEmbedding(embedding: Double) = apply {
146+
fun addEmbedding(embedding: EmbeddingValue) = apply {
130147
this.embedding =
131-
(this.embedding ?: JsonField.of(mutableListOf())).also {
132-
checkKnown("embedding", it).add(embedding)
133-
}
148+
(this.embedding ?: JsonField.of(embedding)).also { checkKnown("embedding", it) }
134149
}
135150

136151
/** The index of the embedding in the list of embeddings. */
@@ -163,7 +178,13 @@ private constructor(
163178

164179
fun build(): Embedding =
165180
Embedding(
166-
checkRequired("embedding", embedding).map { it.toImmutable() },
181+
checkRequired("embedding", embedding).map {
182+
EmbeddingValue(
183+
floatEmbedding =
184+
Optional.of(it.floatEmbedding.orElse(mutableListOf()).toMutableList()),
185+
base64Embedding = it.base64Embedding,
186+
)
187+
},
167188
checkRequired("index", index),
168189
object_,
169190
additionalProperties.toImmutable(),

openai-java-core/src/main/kotlin/com/openai/models/embeddings/EmbeddingCreateParams.kt

+3-1
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,9 @@ private constructor(
271271
private var input: JsonField<Input>? = null
272272
private var model: JsonField<EmbeddingModel>? = null
273273
private var dimensions: JsonField<Long> = JsonMissing.of()
274-
private var encodingFormat: JsonField<EncodingFormat> = JsonMissing.of()
274+
// Default EncodingFormat value is set to BASE64 for performance improvements.
275+
private var encodingFormat: JsonField<EncodingFormat> =
276+
JsonField.of(EncodingFormat.BASE64)
275277
private var user: JsonField<String> = JsonMissing.of()
276278
private var additionalProperties: MutableMap<String, JsonValue> = mutableMapOf()
277279

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
package com.openai.models.embeddings
2+
3+
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
4+
import java.nio.ByteBuffer
5+
import java.nio.ByteOrder
6+
import java.util.Base64
7+
import java.util.Optional
8+
import kotlin.collections.MutableList
9+
10+
/**
 * Represents an embedding vector returned by embedding endpoint.
 *
 * The vector can be held as a Base64 string, as a list of numbers, or both. When only the Base64
 * form is present, the numeric form is decoded from it the first time [floatEmbedding] is read.
 *
 * @param base64Embedding the Base64 text of the vector, decoded as little-endian float32 values.
 * @param floatEmbedding the numeric form of the vector, if the caller already has it.
 */
@JsonDeserialize(using = EmbeddingValueDeserializer::class)
class EmbeddingValue(
    var base64Embedding: Optional<String> = Optional.empty(),
    floatEmbedding: Optional<MutableList<Double>> = Optional.empty(),
) {

    /**
     * The embedding vector, which is a list of float32.
     * [embedding guide](https://platform.openai.com/docs/guides/embeddings).
     *
     * The backing field starts out as the `floatEmbedding` constructor argument (initialising it
     * to an empty Optional would silently discard the caller's list). If it is still empty when
     * read and [base64Embedding] is present, the Base64 text is decoded and the result is cached
     * in the backing field.
     */
    var floatEmbedding: Optional<MutableList<Double>> = floatEmbedding
        get() {
            if (!field.isPresent && base64Embedding.isPresent) {
                field = convertBase64ToFloat(base64Embedding)
            }
            return field
        }

    /**
     * Decodes a Base64 float32 array into an `Optional<MutableList<Double>>`.
     *
     * Requests are made in Base64 by default for performance, but not every caller wants to
     * decode Base64 themselves, so the decoded list is made available alongside the raw text.
     *
     * @param base64Embedding the Base64 text to decode; an empty Optional yields an empty Optional.
     * @return the decoded values, one `Double` per 4-byte little-endian float.
     * @throws IllegalArgumentException if the text is not valid Base64 or its decoded length is
     *   not a whole number of 4-byte floats.
     */
    private fun convertBase64ToFloat(
        base64Embedding: Optional<String>
    ): Optional<MutableList<Double>> =
        base64Embedding.map { base64String ->
            val decoded = Base64.getDecoder().decode(base64String)
            // Reject truncated payloads up front instead of failing mid-read with a
            // BufferUnderflowException that carries no context.
            require(decoded.size % Float.SIZE_BYTES == 0) {
                "Base64 embedding decodes to ${decoded.size} bytes, " +
                    "which is not a multiple of ${Float.SIZE_BYTES}"
            }
            val byteBuffer = ByteBuffer.wrap(decoded).order(ByteOrder.LITTLE_ENDIAN)

            // Each float is routed through its decimal string before becoming a Double.
            // Widening a Float directly exposes the float32 rounding error as extra digits,
            // whereas parsing the Float's own string yields the same Double that parsing the
            // equivalent JSON number text would. No suffix stripping is needed or wanted:
            // Float.toString() never appends "f", and removing "f" would corrupt "Infinity".
            MutableList(decoded.size / Float.SIZE_BYTES) { byteBuffer.float.toString().toDouble() }
        }

    /**
     * Output the embedding vector as a string.
     *
     * When [base64Embedding] is present, both the Base64 form and the list of floats are printed;
     * otherwise only the list of floats is printed. A value with neither form set prints an empty
     * list rather than throwing.
     */
    override fun toString(): String {
        val floats = floatEmbedding.map { it.joinToString(", ") }.orElse("")
        return if (base64Embedding.isPresent) {
            "base64: $base64Embedding, float: [$floats]"
        } else {
            "float: [$floats]"
        }
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package com.openai.models.embeddings
2+
3+
import com.fasterxml.jackson.core.JsonParser
4+
import com.fasterxml.jackson.databind.DeserializationContext
5+
import com.fasterxml.jackson.databind.JsonDeserializer
6+
import com.fasterxml.jackson.databind.JsonNode
7+
import com.fasterxml.jackson.databind.node.ArrayNode
8+
import java.io.IOException
9+
import java.util.Optional
10+
11+
/** Jackson deserializer that turns the JSON form of an embedding into an [EmbeddingValue]. */
class EmbeddingValueDeserializer : JsonDeserializer<EmbeddingValue>() {

    /**
     * Deserialize the JSON representation of an EmbeddingValue.
     *
     * A JSON array fills `floatEmbedding`, with every element read through `asDouble()`. A JSON
     * string fills `base64Embedding`. Any other node leaves both unset.
     *
     * @param jp parser positioned at the value to read.
     * @param ctxt deserialization context (not used here).
     * @return a new [EmbeddingValue] populated as described above.
     */
    @Throws(IOException::class)
    override fun deserialize(jp: JsonParser, ctxt: DeserializationContext): EmbeddingValue {
        val tree = jp.codec.readTree<JsonNode>(jp)
        val result = EmbeddingValue()

        // Values are assigned through the property setters rather than constructor arguments.
        when {
            tree.isArray ->
                result.floatEmbedding = Optional.of(tree.map { it.asDouble() }.toMutableList())
            tree.isTextual -> result.base64Embedding = Optional.of(tree.asText())
        }
        return result
    }
}

openai-java-core/src/test/kotlin/com/openai/models/embeddings/CreateEmbeddingResponseTest.kt

+23-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
package com.openai.models.embeddings
44

5+
import java.util.Optional
56
import org.assertj.core.api.Assertions.assertThat
67
import org.junit.jupiter.api.Test
78

@@ -11,15 +12,35 @@ class CreateEmbeddingResponseTest {
1112
fun createCreateEmbeddingResponse() {
1213
val createEmbeddingResponse =
1314
CreateEmbeddingResponse.builder()
14-
.addData(Embedding.builder().addEmbedding(0.0).index(0L).build())
15+
.addData(
16+
Embedding.builder()
17+
.addEmbedding(
18+
EmbeddingValue(
19+
floatEmbedding = Optional.of(mutableListOf(0.0)),
20+
base64Embedding = Optional.empty(),
21+
)
22+
)
23+
.index(0L)
24+
.build()
25+
)
1526
.model("model")
1627
.usage(
1728
CreateEmbeddingResponse.Usage.builder().promptTokens(0L).totalTokens(0L).build()
1829
)
1930
.build()
2031
assertThat(createEmbeddingResponse).isNotNull
2132
assertThat(createEmbeddingResponse.data())
22-
.containsExactly(Embedding.builder().addEmbedding(0.0).index(0L).build())
33+
.containsExactly(
34+
Embedding.builder()
35+
.addEmbedding(
36+
EmbeddingValue(
37+
floatEmbedding = Optional.of(mutableListOf(0.0)),
38+
base64Embedding = Optional.empty(),
39+
)
40+
)
41+
.index(0L)
42+
.build()
43+
)
2344
assertThat(createEmbeddingResponse.model()).isEqualTo("model")
2445
assertThat(createEmbeddingResponse.usage())
2546
.isEqualTo(
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,26 @@
11
// File generated from our OpenAPI spec by Stainless.
22

33
package com.openai.models.embeddings
4-
4+
import java.util.Optional
55
import org.assertj.core.api.Assertions.assertThat
66
import org.junit.jupiter.api.Test
77

88
class EmbeddingTest {

    /** Builds an [Embedding] from a float-only [EmbeddingValue] and checks the stored values. */
    @Test
    fun createEmbedding() {
        val embedding =
            Embedding.builder()
                .addEmbedding(
                    EmbeddingValue(
                        floatEmbedding = Optional.of(mutableListOf(0.0)),
                        base64Embedding = Optional.empty(),
                    )
                )
                // `index` is checked as a required field by build() and is asserted on below.
                .index(0L)
                .build()
        assertThat(embedding).isNotNull
        // Compare by value: the expected list is a fresh instance, so an identity check
        // (containsSame) could never match it.
        assertThat(embedding.embedding().floatEmbedding).contains(mutableListOf(0.0))
        assertThat(embedding.index()).isEqualTo(0L)
    }
}

0 commit comments

Comments
 (0)