Skip to content

Commit 88e0fe3

Browse files
authored
For OpenAI, make all fields required in JSON schema. (#213)
We use union type (with `null`) to mark fields optional. Background: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas#all-fields-must-be-required
1 parent d1cdd56 commit 88e0fe3

File tree

4 files changed

+59
-14
lines changed

4 files changed

+59
-14
lines changed

src/base/json_schema.rs

+34-12
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,19 @@ use schemars::schema::{
33
ArrayValidation, InstanceType, Metadata, ObjectValidation, Schema, SchemaObject, SingleOrVec,
44
};
55

6+
pub struct ToJsonSchemaOptions {
7+
/// If true, mark all fields as required.
8+
/// Use union type (with `null`) for optional fields instead.
9+
/// Models like OpenAI will reject the schema if a field is not required.
10+
pub fields_always_required: bool,
11+
}
12+
613
pub trait ToJsonSchema {
7-
fn to_json_schema(&self) -> SchemaObject;
14+
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject;
815
}
916

1017
impl ToJsonSchema for schema::BasicValueType {
11-
fn to_json_schema(&self) -> SchemaObject {
18+
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject {
1219
let mut schema = SchemaObject::default();
1320
match self {
1421
schema::BasicValueType::Str => {
@@ -59,7 +66,7 @@ impl ToJsonSchema for schema::BasicValueType {
5966
schema.instance_type = Some(SingleOrVec::Single(Box::new(InstanceType::Array)));
6067
schema.array = Some(Box::new(ArrayValidation {
6168
items: Some(SingleOrVec::Single(Box::new(
62-
s.element_type.to_json_schema().into(),
69+
s.element_type.to_json_schema(options).into(),
6370
))),
6471
min_items: s.dimension.and_then(|d| u32::try_from(d).ok()),
6572
max_items: s.dimension.and_then(|d| u32::try_from(d).ok()),
@@ -72,7 +79,7 @@ impl ToJsonSchema for schema::BasicValueType {
7279
}
7380

7481
impl ToJsonSchema for schema::StructSchema {
75-
fn to_json_schema(&self) -> SchemaObject {
82+
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject {
7683
SchemaObject {
7784
metadata: Some(Box::new(Metadata {
7885
description: self.description.as_ref().map(|s| s.to_string()),
@@ -83,12 +90,25 @@ impl ToJsonSchema for schema::StructSchema {
8390
properties: self
8491
.fields
8592
.iter()
86-
.map(|f| (f.name.to_string(), f.value_type.to_json_schema().into()))
93+
.map(|f| {
94+
let mut schema = f.value_type.to_json_schema(options);
95+
if options.fields_always_required && f.value_type.nullable {
96+
if let Some(instance_type) = &mut schema.instance_type {
97+
let mut types = match instance_type {
98+
SingleOrVec::Single(t) => vec![**t],
99+
SingleOrVec::Vec(t) => std::mem::take(t),
100+
};
101+
types.push(InstanceType::Null);
102+
*instance_type = SingleOrVec::Vec(types);
103+
}
104+
}
105+
(f.name.to_string(), schema.into())
106+
})
87107
.collect(),
88108
required: self
89109
.fields
90110
.iter()
91-
.filter(|&f| (!f.value_type.nullable))
111+
.filter(|&f| (options.fields_always_required || !f.value_type.nullable))
92112
.map(|f| f.name.to_string())
93113
.collect(),
94114
additional_properties: Some(Schema::Bool(false).into()),
@@ -100,14 +120,16 @@ impl ToJsonSchema for schema::StructSchema {
100120
}
101121

102122
impl ToJsonSchema for schema::ValueType {
103-
fn to_json_schema(&self) -> SchemaObject {
123+
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject {
104124
match self {
105-
schema::ValueType::Basic(b) => b.to_json_schema(),
106-
schema::ValueType::Struct(s) => s.to_json_schema(),
125+
schema::ValueType::Basic(b) => b.to_json_schema(options),
126+
schema::ValueType::Struct(s) => s.to_json_schema(options),
107127
schema::ValueType::Collection(c) => SchemaObject {
108128
instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Array))),
109129
array: Some(Box::new(ArrayValidation {
110-
items: Some(SingleOrVec::Single(Box::new(c.row.to_json_schema().into()))),
130+
items: Some(SingleOrVec::Single(Box::new(
131+
c.row.to_json_schema(options).into(),
132+
))),
111133
..Default::default()
112134
})),
113135
..Default::default()
@@ -117,7 +139,7 @@ impl ToJsonSchema for schema::ValueType {
117139
}
118140

119141
impl ToJsonSchema for schema::EnrichedValueType {
120-
fn to_json_schema(&self) -> SchemaObject {
121-
self.typ.to_json_schema()
142+
fn to_json_schema(&self, options: &ToJsonSchemaOptions) -> SchemaObject {
143+
self.typ.to_json_schema(options)
122144
}
123145
}

src/llm/mod.rs

+15
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use async_trait::async_trait;
55
use schemars::schema::SchemaObject;
66
use serde::{Deserialize, Serialize};
77

8+
use crate::base::json_schema::ToJsonSchemaOptions;
9+
810
#[derive(Debug, Clone, Serialize, Deserialize)]
911
pub enum LlmApiType {
1012
Ollama,
@@ -44,6 +46,19 @@ pub trait LlmGenerationClient: Send + Sync {
4446
&self,
4547
request: LlmGenerateRequest<'req>,
4648
) -> Result<LlmGenerateResponse>;
49+
50+
/// If true, the LLM only accepts a JSON schema with all fields required.
51+
/// This is a limitation of LLM models such as OpenAI.
52+
/// Otherwise, the LLM will accept a JSON schema with optional fields.
53+
fn json_schema_fields_always_required(&self) -> bool {
54+
false
55+
}
56+
57+
fn to_json_schema_options(&self) -> ToJsonSchemaOptions {
58+
ToJsonSchemaOptions {
59+
fields_always_required: self.json_schema_fields_always_required(),
60+
}
61+
}
4762
}
4863

4964
mod ollama;

src/llm/openai.rs

+4
Original file line numberDiff line numberDiff line change
@@ -97,4 +97,8 @@ impl LlmGenerationClient for Client {
9797

9898
Ok(super::LlmGenerateResponse { text })
9999
}
100+
101+
fn json_schema_fields_always_required(&self) -> bool {
102+
true
103+
}
100104
}

src/ops/functions/extract_by_llm.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,14 @@ Output only the JSON without any additional messages or explanations."
4747

4848
impl Executor {
4949
async fn new(spec: Spec, args: Args) -> Result<Self> {
50+
let client = new_llm_generation_client(spec.llm_spec).await?;
51+
let output_json_schema = spec
52+
.output_type
53+
.to_json_schema(&client.to_json_schema_options());
5054
Ok(Self {
5155
args,
52-
client: new_llm_generation_client(spec.llm_spec).await?,
53-
output_json_schema: spec.output_type.to_json_schema(),
56+
client,
57+
output_json_schema,
5458
output_type: spec.output_type,
5559
system_prompt: get_system_prompt(&spec.instruction),
5660
})

0 commit comments

Comments
 (0)