@@ -6,14 +6,18 @@ import com.google.cloud.vertexai.api.{
6
6
FileData ,
7
7
GenerateContentResponse ,
8
8
GenerationConfig ,
9
- Part
9
+ Part ,
10
+ Schema ,
11
+ Type
10
12
}
13
+ import io .cequence .openaiscala .OpenAIScalaClientException
11
14
import io .cequence .openaiscala .domain .{
12
15
AssistantMessage ,
13
16
BaseMessage ,
14
17
ChatRole ,
15
18
DeveloperMessage ,
16
19
ImageURLContent ,
20
+ JsonSchema ,
17
21
MessageSpec ,
18
22
SystemMessage ,
19
23
TextContent ,
@@ -25,11 +29,15 @@ import io.cequence.openaiscala.domain.response.{
25
29
ChatCompletionResponse ,
26
30
UsageInfo => OpenAIUsageInfo
27
31
}
28
- import io .cequence .openaiscala .domain .settings .CreateChatCompletionSettings
32
+ import io .cequence .openaiscala .domain .settings .{
33
+ ChatCompletionResponseFormatType ,
34
+ CreateChatCompletionSettings
35
+ }
29
36
30
37
import java .{util => ju }
31
- import scala .collection .convert .ImplicitConversions .`list asScalaBuffer`
32
38
import scala .collection .convert .ImplicitConversions .`iterable asJava`
39
+ import scala .collection .convert .ImplicitConversions .`map AsJavaMap`
40
+ import scala .collection .convert .ImplicitConversions .`list asScalaBuffer`
33
41
34
42
package object impl {
35
43
@@ -43,32 +51,30 @@ package object impl {
43
51
.build()
44
52
45
53
case UserSeqMessage (contents, _) =>
46
- val parts = contents.map { content =>
47
- content match {
48
- case TextContent (text) =>
49
- Part .newBuilder().setText(text).build()
50
-
51
- case ImageURLContent (url) =>
52
- if (url.startsWith(" data:" )) {
53
- val mediaTypeEncodingAndData = url.drop(5 )
54
- val mediaType = mediaTypeEncodingAndData.takeWhile(_ != ';' )
55
- val encodingAndData = mediaTypeEncodingAndData.drop(mediaType.length + 1 )
56
- val encoding = mediaType.takeWhile(_ != ',' )
57
- val data = encodingAndData.drop(encoding.length + 1 )
58
-
59
- // TODO: try this
60
- Part
61
- .newBuilder()
62
- .setFileData(
63
- FileData .newBuilder().setMimeType(mediaType).setFileUri(data).build()
64
- )
65
- .build()
66
- } else {
67
- throw new IllegalArgumentException (
68
- " Image content only supported by providing image data directly. Must start with 'data:'."
54
+ val parts = contents.map {
55
+ case TextContent (text) =>
56
+ Part .newBuilder().setText(text).build()
57
+
58
+ case ImageURLContent (url) =>
59
+ if (url.startsWith(" data:" )) {
60
+ val mediaTypeEncodingAndData = url.drop(5 )
61
+ val mediaType = mediaTypeEncodingAndData.takeWhile(_ != ';' )
62
+ val encodingAndData = mediaTypeEncodingAndData.drop(mediaType.length + 1 )
63
+ val encoding = mediaType.takeWhile(_ != ',' )
64
+ val data = encodingAndData.drop(encoding.length + 1 )
65
+
66
+ // TODO: try this
67
+ Part
68
+ .newBuilder()
69
+ .setFileData(
70
+ FileData .newBuilder().setMimeType(mediaType).setFileUri(data).build()
69
71
)
70
- }
71
- }
72
+ .build()
73
+ } else {
74
+ throw new IllegalArgumentException (
75
+ " Image content only supported by providing image data directly. Must start with 'data:'."
76
+ )
77
+ }
72
78
}
73
79
74
80
val contentBuilder = Content .newBuilder().setRole(" USER" )
@@ -167,9 +173,84 @@ package object impl {
167
173
if (settings.stop.nonEmpty) Some (`iterable asJava`(settings.stop)) else None
168
174
)
169
175
176
+ // handle json schema
177
+ val responseFormat =
178
+ settings.response_format_type.getOrElse(ChatCompletionResponseFormatType .text)
179
+
180
+ val jsonSchema =
181
+ if (
182
+ responseFormat == ChatCompletionResponseFormatType .json_schema && settings.jsonSchema.isDefined
183
+ ) {
184
+ settings.jsonSchema.get.structure match {
185
+ case Left (schema) =>
186
+ Some (toVertexJSONSchema(schema))
187
+ case Right (_) =>
188
+ None
189
+ }
190
+ } else
191
+ None
192
+
193
+ jsonSchema.foreach { schema =>
194
+ configBuilder.setResponseSchema(schema)
195
+ configBuilder.setResponseMimeType(" application/json" )
196
+ }
197
+
170
198
configBuilder.build()
171
199
}
172
200
201
+ private def toVertexJSONSchema (
202
+ jsonSchema : JsonSchema
203
+ ): Schema = {
204
+ val builder = Schema .newBuilder()
205
+
206
+ jsonSchema match {
207
+ case JsonSchema .String (description, enumVals) =>
208
+ builder.setType(Type .STRING )
209
+ description.foreach(builder.setDescription)
210
+ enumVals.foreach(builder.addEnum)
211
+
212
+ case JsonSchema .Number (description) =>
213
+ val b = builder.setType(Type .NUMBER )
214
+ description.foreach(b.setDescription)
215
+
216
+ case JsonSchema .Integer (description) =>
217
+ val b = builder.setType(Type .INTEGER )
218
+ description.foreach(b.setDescription)
219
+
220
+ case JsonSchema .Boolean (description) =>
221
+ val b = builder.setType(Type .BOOLEAN )
222
+ description.foreach(b.setDescription)
223
+
224
+ case JsonSchema .Null () =>
225
+ builder.setType(Type .TYPE_UNSPECIFIED )
226
+
227
+ case JsonSchema .Object (properties, required) =>
228
+ val b = builder.setType(Type .OBJECT )
229
+ if (properties.nonEmpty) {
230
+ val propsMap = properties.map { case (key, jsonSchema) =>
231
+ key -> toVertexJSONSchema(jsonSchema)
232
+ }.toMap
233
+
234
+ b.putAllProperties(`map AsJavaMap`(propsMap))
235
+ }
236
+
237
+ if (required.nonEmpty) {
238
+ b.addAllRequired(`iterable asJava`(required))
239
+ }
240
+
241
+ case JsonSchema .Array (items) =>
242
+ val b = builder.setType(Type .ARRAY )
243
+ b.setItems(toVertexJSONSchema(items))
244
+
245
+ case _ =>
246
+ throw new OpenAIScalaClientException (
247
+ s " Unsupported JSON schema type for Google Vertex. "
248
+ )
249
+ }
250
+
251
+ builder.build()
252
+ }
253
+
173
254
def toOpenAI (
174
255
response : GenerateContentResponse ,
175
256
model : String
0 commit comments