Skip to content

Commit ceece5f

Browse files
committed
Google vertex - json schema support
1 parent 5773e7f commit ceece5f

File tree

10 files changed

+223
-69
lines changed

10 files changed

+223
-69
lines changed

anthropic-client/src/main/scala/io/cequence/openaiscala/anthropic/service/impl/AnthropicServiceImpl.scala

-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ import io.cequence.openaiscala.anthropic.domain.response.{
99
import io.cequence.openaiscala.anthropic.domain.settings.AnthropicCreateMessageSettings
1010
import io.cequence.openaiscala.anthropic.domain.Message
1111
import io.cequence.wsclient.ResponseImplicits.JsonSafeOps
12-
import play.api.libs.json.Json
1312

1413
import scala.concurrent.Future
1514

google-gemini-client/src/main/scala/io/cequence/openaiscala/gemini/service/impl/OpenAIGeminiChatCompletionService.scala

+5
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,11 @@ private[service] class OpenAIGeminiChatCompletionService(
322322
description = description
323323
)
324324

325+
case JsonSchema.Null() =>
326+
Schema(
327+
`type` = SchemaType.TYPE_UNSPECIFIED
328+
)
329+
325330
case JsonSchema.Object(properties, required) =>
326331
Schema(
327332
`type` = SchemaType.OBJECT,

google-vertexai-client/src/main/scala/io/cequence/openaiscala/vertexai/service/impl/package.scala

+109-28
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,18 @@ import com.google.cloud.vertexai.api.{
66
FileData,
77
GenerateContentResponse,
88
GenerationConfig,
9-
Part
9+
Part,
10+
Schema,
11+
Type
1012
}
13+
import io.cequence.openaiscala.OpenAIScalaClientException
1114
import io.cequence.openaiscala.domain.{
1215
AssistantMessage,
1316
BaseMessage,
1417
ChatRole,
1518
DeveloperMessage,
1619
ImageURLContent,
20+
JsonSchema,
1721
MessageSpec,
1822
SystemMessage,
1923
TextContent,
@@ -25,11 +29,15 @@ import io.cequence.openaiscala.domain.response.{
2529
ChatCompletionResponse,
2630
UsageInfo => OpenAIUsageInfo
2731
}
28-
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
32+
import io.cequence.openaiscala.domain.settings.{
33+
ChatCompletionResponseFormatType,
34+
CreateChatCompletionSettings
35+
}
2936

3037
import java.{util => ju}
31-
import scala.collection.convert.ImplicitConversions.`list asScalaBuffer`
3238
import scala.collection.convert.ImplicitConversions.`iterable asJava`
39+
import scala.collection.convert.ImplicitConversions.`map AsJavaMap`
40+
import scala.collection.convert.ImplicitConversions.`list asScalaBuffer`
3341

3442
package object impl {
3543

@@ -43,32 +51,30 @@ package object impl {
4351
.build()
4452

4553
case UserSeqMessage(contents, _) =>
46-
val parts = contents.map { content =>
47-
content match {
48-
case TextContent(text) =>
49-
Part.newBuilder().setText(text).build()
50-
51-
case ImageURLContent(url) =>
52-
if (url.startsWith("data:")) {
53-
val mediaTypeEncodingAndData = url.drop(5)
54-
val mediaType = mediaTypeEncodingAndData.takeWhile(_ != ';')
55-
val encodingAndData = mediaTypeEncodingAndData.drop(mediaType.length + 1)
56-
val encoding = mediaType.takeWhile(_ != ',')
57-
val data = encodingAndData.drop(encoding.length + 1)
58-
59-
// TODO: try this
60-
Part
61-
.newBuilder()
62-
.setFileData(
63-
FileData.newBuilder().setMimeType(mediaType).setFileUri(data).build()
64-
)
65-
.build()
66-
} else {
67-
throw new IllegalArgumentException(
68-
"Image content only supported by providing image data directly. Must start with 'data:'."
54+
val parts = contents.map {
55+
case TextContent(text) =>
56+
Part.newBuilder().setText(text).build()
57+
58+
case ImageURLContent(url) =>
59+
if (url.startsWith("data:")) {
60+
val mediaTypeEncodingAndData = url.drop(5)
61+
val mediaType = mediaTypeEncodingAndData.takeWhile(_ != ';')
62+
val encodingAndData = mediaTypeEncodingAndData.drop(mediaType.length + 1)
63+
val encoding = mediaType.takeWhile(_ != ',')
64+
val data = encodingAndData.drop(encoding.length + 1)
65+
66+
// TODO: try this
67+
Part
68+
.newBuilder()
69+
.setFileData(
70+
FileData.newBuilder().setMimeType(mediaType).setFileUri(data).build()
6971
)
70-
}
71-
}
72+
.build()
73+
} else {
74+
throw new IllegalArgumentException(
75+
"Image content only supported by providing image data directly. Must start with 'data:'."
76+
)
77+
}
7278
}
7379

7480
val contentBuilder = Content.newBuilder().setRole("USER")
@@ -167,9 +173,84 @@ package object impl {
167173
if (settings.stop.nonEmpty) Some(`iterable asJava`(settings.stop)) else None
168174
)
169175

176+
// handle json schema
177+
val responseFormat =
178+
settings.response_format_type.getOrElse(ChatCompletionResponseFormatType.text)
179+
180+
val jsonSchema =
181+
if (
182+
responseFormat == ChatCompletionResponseFormatType.json_schema && settings.jsonSchema.isDefined
183+
) {
184+
settings.jsonSchema.get.structure match {
185+
case Left(schema) =>
186+
Some(toVertexJSONSchema(schema))
187+
case Right(_) =>
188+
None
189+
}
190+
} else
191+
None
192+
193+
jsonSchema.foreach { schema =>
194+
configBuilder.setResponseSchema(schema)
195+
configBuilder.setResponseMimeType("application/json")
196+
}
197+
170198
configBuilder.build()
171199
}
172200

201+
private def toVertexJSONSchema(
202+
jsonSchema: JsonSchema
203+
): Schema = {
204+
val builder = Schema.newBuilder()
205+
206+
jsonSchema match {
207+
case JsonSchema.String(description, enumVals) =>
208+
builder.setType(Type.STRING)
209+
description.foreach(builder.setDescription)
210+
enumVals.foreach(builder.addEnum)
211+
212+
case JsonSchema.Number(description) =>
213+
val b = builder.setType(Type.NUMBER)
214+
description.foreach(b.setDescription)
215+
216+
case JsonSchema.Integer(description) =>
217+
val b = builder.setType(Type.INTEGER)
218+
description.foreach(b.setDescription)
219+
220+
case JsonSchema.Boolean(description) =>
221+
val b = builder.setType(Type.BOOLEAN)
222+
description.foreach(b.setDescription)
223+
224+
case JsonSchema.Null() =>
225+
builder.setType(Type.TYPE_UNSPECIFIED)
226+
227+
case JsonSchema.Object(properties, required) =>
228+
val b = builder.setType(Type.OBJECT)
229+
if (properties.nonEmpty) {
230+
val propsMap = properties.map { case (key, jsonSchema) =>
231+
key -> toVertexJSONSchema(jsonSchema)
232+
}.toMap
233+
234+
b.putAllProperties(`map AsJavaMap`(propsMap))
235+
}
236+
237+
if (required.nonEmpty) {
238+
b.addAllRequired(`iterable asJava`(required))
239+
}
240+
241+
case JsonSchema.Array(items) =>
242+
val b = builder.setType(Type.ARRAY)
243+
b.setItems(toVertexJSONSchema(items))
244+
245+
case _ =>
246+
throw new OpenAIScalaClientException(
247+
s"Unsupported JSON schema type for Google Vertex."
248+
)
249+
}
250+
251+
builder.build()
252+
}
253+
173254
def toOpenAI(
174255
response: GenerateContentResponse,
175256
model: String

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIChatCompletionExtra.scala

+27-27
Original file line numberDiff line numberDiff line change
@@ -146,31 +146,29 @@ object OpenAIChatCompletionExtra {
146146
ModelId.o1,
147147
ModelId.o1_2024_12_17,
148148
ModelId.o3_mini,
149-
ModelId.o3_mini_2025_01_31
150-
).flatMap(id => Seq(id, "openai-" + id, "azure-" + id)) ++
151-
Seq(
152-
NonOpenAIModelId.gemini_2_0_flash,
153-
NonOpenAIModelId.gemini_2_0_flash_001,
154-
NonOpenAIModelId.gemini_2_0_pro_exp_02_05,
155-
NonOpenAIModelId.gemini_2_0_pro_exp,
156-
NonOpenAIModelId.gemini_2_0_flash_001,
157-
NonOpenAIModelId.gemini_2_0_flash,
158-
NonOpenAIModelId.gemini_2_0_flash_exp,
159-
NonOpenAIModelId.gemini_1_5_flash_8b_exp_0924,
160-
NonOpenAIModelId.gemini_1_5_flash_8b_exp_0827,
161-
NonOpenAIModelId.gemini_1_5_flash_8b_latest,
162-
NonOpenAIModelId.gemini_1_5_flash_8b_001,
163-
NonOpenAIModelId.gemini_1_5_flash_8b,
164-
NonOpenAIModelId.gemini_1_5_flash_002,
165-
NonOpenAIModelId.gemini_1_5_flash,
166-
NonOpenAIModelId.gemini_1_5_flash_001,
167-
NonOpenAIModelId.gemini_1_5_flash_latest,
168-
NonOpenAIModelId.gemini_1_5_pro,
169-
NonOpenAIModelId.gemini_1_5_pro_002,
170-
NonOpenAIModelId.gemini_1_5_pro_001,
171-
NonOpenAIModelId.gemini_1_5_pro_latest,
172-
NonOpenAIModelId.gemini_exp_1206
173-
).flatMap(id => Seq(id, "google_gemini-" + id))
149+
ModelId.o3_mini_2025_01_31,
150+
NonOpenAIModelId.gemini_2_0_flash,
151+
NonOpenAIModelId.gemini_2_0_flash_001,
152+
NonOpenAIModelId.gemini_2_0_pro_exp_02_05,
153+
NonOpenAIModelId.gemini_2_0_pro_exp,
154+
NonOpenAIModelId.gemini_2_0_flash_001,
155+
NonOpenAIModelId.gemini_2_0_flash,
156+
NonOpenAIModelId.gemini_2_0_flash_exp,
157+
NonOpenAIModelId.gemini_1_5_flash_8b_exp_0924,
158+
NonOpenAIModelId.gemini_1_5_flash_8b_exp_0827,
159+
NonOpenAIModelId.gemini_1_5_flash_8b_latest,
160+
NonOpenAIModelId.gemini_1_5_flash_8b_001,
161+
NonOpenAIModelId.gemini_1_5_flash_8b,
162+
NonOpenAIModelId.gemini_1_5_flash_002,
163+
NonOpenAIModelId.gemini_1_5_flash,
164+
NonOpenAIModelId.gemini_1_5_flash_001,
165+
NonOpenAIModelId.gemini_1_5_flash_latest,
166+
NonOpenAIModelId.gemini_1_5_pro,
167+
NonOpenAIModelId.gemini_1_5_pro_002,
168+
NonOpenAIModelId.gemini_1_5_pro_001,
169+
NonOpenAIModelId.gemini_1_5_pro_latest,
170+
NonOpenAIModelId.gemini_exp_1206
171+
)
174172

175173
def handleOutputJsonSchema(
176174
messages: Seq[BaseMessage],
@@ -184,8 +182,9 @@ object OpenAIChatCompletionExtra {
184182
val jsonSchemaJson = Json.toJson(jsonSchemaDef.structure)
185183
val jsonSchemaString = Json.prettyPrint(jsonSchemaJson)
186184

187-
val (settingsFinal, addJsonToPrompt) =
188-
if (jsonSchemaModels.contains(settings.model)) {
185+
val (settingsFinal, addJsonToPrompt) = {
186+
// to be more robust we also match models with a suffix
187+
if (jsonSchemaModels.contains((model: String) => (settings.model == model) || (settings.model.endsWith("-" + model)))) {
189188
logger.debug(
190189
s"Using OpenAI json schema mode for ${taskNameForLogging} and the model '${settings.model}' - name: ${jsonSchemaDef.name}, strict: ${jsonSchemaDef.strict}, structure:\n${jsonSchemaString}"
191190
)
@@ -211,6 +210,7 @@ object OpenAIChatCompletionExtra {
211210
true
212211
)
213212
}
213+
}
214214

215215
val messagesFinal = if (addJsonToPrompt) {
216216
if (messages.nonEmpty && messages.last.role == ChatRole.User) {

openai-examples/src/main/scala/io/cequence/openaiscala/examples/googlegemini/GoogleGeminiCreateChatCompletionJSONWithOpenAIAdapter.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ object GoogleGeminiCreateChatCompletionJSONWithOpenAIAdapter
5656
required = Seq("countries")
5757
)
5858

59-
private val modelId = "google_gemini-" + NonOpenAIModelId.gemini_2_0_flash
59+
private val modelId = NonOpenAIModelId.gemini_2_0_flash
6060

6161
override protected def run: Future[_] =
6262
service
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package io.cequence.openaiscala.examples.googlevertexai
2+
3+
import io.cequence.openaiscala.domain._
4+
import io.cequence.openaiscala.domain.settings.{ChatCompletionResponseFormatType, CreateChatCompletionSettings, JsonSchemaDef}
5+
import io.cequence.openaiscala.examples.{ChatCompletionProvider, ExampleBase, TestFixtures}
6+
import io.cequence.openaiscala.service.OpenAIChatCompletionExtra._
7+
import io.cequence.openaiscala.service.OpenAIChatCompletionService
8+
import play.api.libs.json.{JsObject, Json}
9+
10+
import scala.concurrent.Future
11+
12+
// requires `openai-scala-google-vertexai-client` as a dependency and `VERTEXAI_LOCATION` and `VERTEXAI_PROJECT_ID` environments variable to be set
13+
object GoogleVertexAICreateChatCompletionJSONWithOpenAIAdapter
14+
extends ExampleBase[OpenAIChatCompletionService]
15+
with TestFixtures {
16+
17+
override val service: OpenAIChatCompletionService = ChatCompletionProvider.vertexAI
18+
19+
private val messages = Seq(
20+
SystemMessage("You are an expert geographer"),
21+
UserMessage("List all Asian countries in the prescribed JSON format.")
22+
)
23+
24+
private val jsonSchema = JsonSchema.Object(
25+
properties = Seq(
26+
"countries" -> JsonSchema.Array(
27+
JsonSchema.Object(
28+
properties = Seq(
29+
"country" -> JsonSchema.String(),
30+
"capital" -> JsonSchema.String(),
31+
"countrySize" -> JsonSchema.String(
32+
`enum` = Seq("small", "medium", "large")
33+
),
34+
"commonwealthMember" -> JsonSchema.Boolean(),
35+
"populationMil" -> JsonSchema.Integer(),
36+
"ratioOfMenToWomen" -> JsonSchema.Number()
37+
),
38+
required = Seq(
39+
"country",
40+
"capital",
41+
"countrySize",
42+
"commonwealthMember",
43+
"populationMil",
44+
"ratioOfMenToWomen"
45+
)
46+
)
47+
)
48+
),
49+
required = Seq("countries")
50+
)
51+
52+
private val modelId = NonOpenAIModelId.gemini_2_0_flash
53+
54+
override protected def run: Future[_] =
55+
service
56+
.createChatCompletionWithJSON[JsObject](
57+
messages = messages,
58+
settings = CreateChatCompletionSettings(
59+
model = modelId,
60+
response_format_type = Some(ChatCompletionResponseFormatType.json_schema),
61+
jsonSchema = Some(
62+
JsonSchemaDef(
63+
name = "countries_response",
64+
strict = true,
65+
structure = jsonSchema
66+
)
67+
)
68+
)
69+
)
70+
.map(json => println(Json.prettyPrint(json)))
71+
}
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package io.cequence.openaiscala.examples.vertexai
1+
package io.cequence.openaiscala.examples.googlevertexai
22

33
import akka.stream.scaladsl.Sink
44
import io.cequence.openaiscala.domain.settings.CreateChatCompletionSettings
@@ -10,7 +10,7 @@ import io.cequence.openaiscala.service.StreamedServiceTypes.OpenAIChatCompletion
1010
import scala.concurrent.Future
1111

1212
// requires `openai-scala-google-vertexai-client` as a dependency and `VERTEXAI_LOCATION` and `VERTEXAI_PROJECT_ID` environments variable to be set
13-
object VertexAICreateChatCompletionStreamedWithOpenAIAdapter
13+
object GoogleVertexAICreateChatCompletionStreamedWithOpenAIAdapter
1414
extends ExampleBase[OpenAIChatCompletionService] {
1515

1616
override val service: OpenAIChatCompletionStreamedService = ChatCompletionProvider.vertexAI
@@ -33,9 +33,8 @@ object VertexAICreateChatCompletionStreamedWithOpenAIAdapter
3333
)
3434
)
3535
.runWith(
36-
Sink.foreach { completion =>
37-
val content = completion.choices.headOption.flatMap(_.delta.content)
38-
print(content.getOrElse(""))
36+
Sink.foreach { response =>
37+
print(response.contentHead.getOrElse(""))
3938
}
4039
)
4140
}

0 commit comments

Comments
 (0)