Skip to content

Commit cf68313

Browse files
committed
Google Gemini - json schema support for OpenAI adapter
1 parent 5da31b4 commit cf68313

File tree

6 files changed

+185
-9
lines changed

6 files changed

+185
-9
lines changed

google-gemini-client/src/main/scala/io/cequence/openaiscala/gemini/service/impl/OpenAIGeminiChatCompletionService.scala

+114-2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ import io.cequence.openaiscala.service.{
4141
}
4242

4343
import scala.concurrent.{ExecutionContext, Future}
44+
import io.cequence.openaiscala.domain.settings.ChatCompletionResponseFormatType
45+
import io.cequence.openaiscala.domain.JsonSchema
46+
import io.cequence.openaiscala.gemini.domain.Schema
47+
import com.typesafe.scalalogging.Logger
48+
import io.cequence.openaiscala.gemini.domain.SchemaType
49+
import org.slf4j.LoggerFactory
50+
51+
import scala.collection.immutable.Traversable
4452

4553
private[service] class OpenAIGeminiChatCompletionService(
4654
underlying: GeminiService
@@ -49,6 +57,8 @@ private[service] class OpenAIGeminiChatCompletionService(
4957
) extends OpenAIChatCompletionService
5058
with OpenAIChatCompletionStreamedServiceExtra {
5159

60+
protected val logger: Logger = Logger(LoggerFactory.getLogger(this.getClass))
61+
5262
override def createChatCompletion(
5363
messages: Seq[BaseMessage],
5464
settings: CreateChatCompletionSettings
@@ -185,7 +195,31 @@ private[service] class OpenAIGeminiChatCompletionService(
185195
private def toGeminiSettings(
186196
settings: CreateChatCompletionSettings,
187197
systemMessage: Option[BaseMessage]
188-
): GenerateContentSettings =
198+
): GenerateContentSettings = {
199+
200+
// handle json schema
201+
val responseFormat =
202+
settings.response_format_type.getOrElse(ChatCompletionResponseFormatType.text)
203+
204+
val jsonSchema =
205+
if (
206+
responseFormat == ChatCompletionResponseFormatType.json_schema && settings.jsonSchema.isDefined
207+
) {
208+
settings.jsonSchema.get.structure match {
209+
case Left(schema) =>
210+
Some(toGeminiJSONSchema(schema))
211+
case Right(_) =>
212+
logger.warn(
213+
"Map-like legacy JSON schema is not supported for conversion to Gemini schema."
214+
)
215+
None
216+
}
217+
} else
218+
None
219+
220+
// check for unsupported fields
221+
checkNotSupported(settings)
222+
189223
GenerateContentSettings(
190224
model = settings.model,
191225
tools = None, // TODO
@@ -196,7 +230,7 @@ private[service] class OpenAIGeminiChatCompletionService(
196230
GenerationConfig(
197231
stopSequences = (if (settings.stop.nonEmpty) Some(settings.stop) else None),
198232
responseMimeType = None,
199-
responseSchema = None, // TODO: support JSON!
233+
responseSchema = jsonSchema,
200234
responseModalities = None,
201235
candidateCount = settings.n,
202236
maxOutputTokens = settings.max_tokens,
@@ -214,6 +248,84 @@ private[service] class OpenAIGeminiChatCompletionService(
214248
),
215249
cachedContent = None
216250
)
251+
}
252+
253+
private def checkNotSupported(
254+
settings: CreateChatCompletionSettings
255+
) = {
256+
def notSupported(
257+
field: CreateChatCompletionSettings => Option[_],
258+
fieldName: String
259+
): Unit =
260+
field(settings).foreach { _ =>
261+
logger.warn(s"Field $fieldName is not yet supported for Gemini. Skipping...")
262+
}
263+
264+
def notSupportedCollection(
265+
field: CreateChatCompletionSettings => Traversable[_],
266+
fieldName: String
267+
): Unit =
268+
if (field(settings).nonEmpty) {
269+
logger.warn(s"Field $fieldName is not supported for Gemini. Skipping...")
270+
}
271+
272+
notSupported(_.reasoning_effort, "reasoning_effort")
273+
notSupported(_.service_tier, "service_tier")
274+
notSupported(_.parallel_tool_calls, "parallel_tool_calls")
275+
notSupportedCollection(_.metadata, "metadata")
276+
notSupportedCollection(_.logit_bias, "logit_bias")
277+
notSupported(_.user, "user")
278+
notSupported(_.store, "store")
279+
}
280+
281+
/**
 * Converts a [[JsonSchema]] into the Gemini [[Schema]] representation.
 *
 * String, number, integer and boolean schemas are mapped to the `SchemaType` of the same
 * name; object and array schemas are converted recursively.
 *
 * @param jsonSchema the schema to convert
 * @return the corresponding Gemini schema
 * @throws OpenAIScalaClientException if the schema is of a kind not handled below
 */
private def toGeminiJSONSchema(
  jsonSchema: JsonSchema
): Schema = jsonSchema match {
  case JsonSchema.String(description, allowedValues) =>
    // NOTE(review): the enum list is always wrapped in Some, even when it is empty —
    // confirm that this is what the Gemini API expects.
    Schema(
      `type` = SchemaType.STRING,
      description = description,
      `enum` = Some(allowedValues)
    )

  case JsonSchema.Number(description) =>
    Schema(`type` = SchemaType.NUMBER, description = description)

  case JsonSchema.Integer(description) =>
    Schema(`type` = SchemaType.INTEGER, description = description)

  case JsonSchema.Boolean(description) =>
    Schema(`type` = SchemaType.BOOLEAN, description = description)

  case JsonSchema.Object(properties, requiredKeys) =>
    // Convert every property schema, keeping its key.
    val convertedProperties = properties.map { case (name, propertySchema) =>
      name -> toGeminiJSONSchema(propertySchema)
    }.toMap

    Schema(
      `type` = SchemaType.OBJECT,
      properties = Some(convertedProperties),
      required = Some(requiredKeys)
    )

  case JsonSchema.Array(itemSchema) =>
    Schema(
      `type` = SchemaType.ARRAY,
      items = Some(toGeminiJSONSchema(itemSchema))
    )

  case _ =>
    throw new OpenAIScalaClientException("Unsupported JSON schema type for Gemini.")
}
217329

218330
private def toOpenAIResponse(
219331
response: GenerateContentResponse

openai-client/src/test/scala/io/cequence/openaiscala/MessageJsonSpec.scala

-3
Original file line numberDiff line numberDiff line change
@@ -221,11 +221,8 @@ class MessageJsonSpec extends Matchers with AnyWordSpecLike {
221221
val json = toJson(message)
222222
val jsonKeys = json.keySet
223223

224-
println(toJsonObject(message))
225-
226224
val messages2 = AssistantMessage(content)
227225
toJson(messages2)
228-
println(toJsonObject(messages2))
229226

230227
// json shouldNot be(json2)
231228

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIChatCompletionExtra.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,9 @@ object OpenAIChatCompletionExtra {
9191
failureMessage = s"${taskNameForLoggingFinal.capitalize} failed."
9292
)
9393
.map { response =>
94-
val content = response.choices.head.message.content
94+
val content = response.contentHead
9595
val contentTrimmed = content.stripPrefix("```json").stripSuffix("```").trim
96-
val contentJson = contentTrimmed.dropWhile(_ != '{')
96+
val contentJson = contentTrimmed.dropWhile(char => char != '{' && char != '[')
9797
val json = parseJson(contentJson)
9898

9999
logger.debug(

openai-examples/src/main/scala-2/io/cequence/openaiscala/examples/CreateChatCompletionJsonForCaseClass.scala

+3-1
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@ object CreateChatCompletionJsonForCaseClass extends Example with JsonSchemaRefle
1515
case class Country(
1616
country: String,
1717
capital: String,
18-
populationMil: Double
18+
populationMil: Int,
19+
ratioOfMenToWomen: Double
1920
)
21+
2022
case class CapitalsResponse(capitals: Seq[Country])
2123

2224
// JSON format and schema

openai-examples/src/main/scala/io/cequence/openaiscala/examples/nonopenai/AnthropicCreateChatCompletionCachedWithOpenAIAdapter.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ object AnthropicCreateChatCompletionCachedWithOpenAIAdapter
3131
) // this is how we pass it through the adapter
3232
)
3333
.map { content =>
34-
println(content.choices.headOption.map(_.message.content).getOrElse("N/A"))
34+
println(content.contentHead)
3535
}
3636
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package io.cequence.openaiscala.examples.nonopenai
2+
3+
import io.cequence.openaiscala.domain._
4+
import io.cequence.openaiscala.domain.settings.{ChatCompletionResponseFormatType, CreateChatCompletionSettings, JsonSchemaDef}
5+
import io.cequence.openaiscala.examples.ExampleBase
6+
import io.cequence.openaiscala.examples.fixtures.TestFixtures
7+
import io.cequence.openaiscala.gemini.service.GeminiServiceFactory
8+
import io.cequence.openaiscala.service.OpenAIChatCompletionService
9+
import io.cequence.openaiscala.service.OpenAIChatCompletionExtra._
10+
import play.api.libs.json.{JsArray, JsObject, Json}
11+
12+
import scala.concurrent.Future
13+
14+
/**
 * Example that asks a Gemini model, through the OpenAI-compatible adapter, to list Asian
 * countries as JSON constrained by an explicit JSON schema, and pretty-prints the result.
 *
 * Requires `GOOGLE_API_KEY` environment variable to be set.
 */
object GoogleGeminiCreateChatCompletionJSONWithOpenAIAdapter
    extends ExampleBase[OpenAIChatCompletionService]
    with TestFixtures {

  override val service: OpenAIChatCompletionService = GeminiServiceFactory.asOpenAI()

  private val messages = Seq(
    SystemMessage("You are an expert geographer"),
    UserMessage("List all Asian countries in the prescribed JSON format.")
  )

  // Schema of a single entry of the "countries" array; every property is required.
  private val countrySchema = JsonSchema.Object(
    properties = Seq(
      "country" -> JsonSchema.String(),
      "capital" -> JsonSchema.String(),
      "countrySize" -> JsonSchema.String(
        `enum` = Seq("small", "medium", "large")
      ),
      "commonwealthMember" -> JsonSchema.Boolean(),
      "populationMil" -> JsonSchema.Integer(),
      "ratioOfMenToWomen" -> JsonSchema.Number()
    ),
    required = Seq(
      "country",
      "capital",
      "countrySize",
      "commonwealthMember",
      "populationMil",
      "ratioOfMenToWomen"
    )
  )

  // Top-level response schema: an object holding the array of countries.
  private val jsonSchema = JsonSchema.Object(
    properties = Seq(
      "countries" -> JsonSchema.Array(countrySchema)
    ),
    required = Seq("countries")
  )

  private val modelId = NonOpenAIModelId.gemini_2_0_flash

  override protected def run: Future[_] = {
    // Request a response in json_schema format using the schema defined above.
    val completionSettings = CreateChatCompletionSettings(
      model = modelId,
      response_format_type = Some(ChatCompletionResponseFormatType.json_schema),
      jsonSchema = Some(
        JsonSchemaDef(
          name = "countries_response",
          strict = true,
          structure = jsonSchema
        )
      )
    )

    service
      .createChatCompletionWithJSON[JsObject](
        messages = messages,
        settings = completionSettings
      )
      .map { responseJson =>
        println(Json.prettyPrint(responseJson))
      }
  }
}

0 commit comments

Comments
 (0)