Skip to content

Commit d63c953

Browse files
committed
json: fix nested $refs & allow mix of properties & anyOf (ggml-org#8073)
1 parent cb0b06a commit d63c953

12 files changed

+427
-374
lines changed

Diff for: common/json-schema-to-grammar.cpp

+132-128
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <unordered_map>
99
#include <unordered_set>
1010
#include <vector>
11+
#include <iostream>
1112

1213
using json = nlohmann::ordered_json;
1314

@@ -392,10 +393,10 @@ class SchemaConverter {
392393
std::function<json(const std::string &)> _fetch_json;
393394
bool _dotall;
394395
std::map<std::string, std::string> _rules;
395-
std::unordered_map<std::string, json> _refs;
396-
std::unordered_set<std::string> _refs_being_resolved;
397396
std::vector<std::string> _errors;
398397
std::vector<std::string> _warnings;
398+
std::unordered_map<std::string, json> _external_refs;
399+
std::vector<json> _ref_context;
399400

400401
std::string _add_rule(const std::string & name, const std::string & rule) {
401402
std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-");
@@ -683,17 +684,6 @@ class SchemaConverter {
683684
return out.str();
684685
}
685686

686-
std::string _resolve_ref(const std::string & ref) {
687-
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
688-
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
689-
_refs_being_resolved.insert(ref);
690-
json resolved = _refs[ref];
691-
ref_name = visit(resolved, ref_name);
692-
_refs_being_resolved.erase(ref);
693-
}
694-
return ref_name;
695-
}
696-
697687
std::string _build_object_rule(
698688
const std::vector<std::pair<std::string, json>> & properties,
699689
const std::unordered_set<std::string> & required,
@@ -815,78 +805,79 @@ class SchemaConverter {
815805
_rules["space"] = SPACE_RULE;
816806
}
817807

818-
void resolve_refs(json & schema, const std::string & url) {
819-
/*
820-
* Resolves all $ref fields in the given schema, fetching any remote schemas,
821-
* replacing each $ref with absolute reference URL and populates _refs with the
822-
* respective referenced (sub)schema dictionaries.
823-
*/
824-
std::function<void(json &)> visit_refs = [&](json & n) {
825-
if (n.is_array()) {
826-
for (auto & x : n) {
827-
visit_refs(x);
828-
}
829-
} else if (n.is_object()) {
830-
if (n.contains("$ref")) {
831-
std::string ref = n["$ref"];
832-
if (_refs.find(ref) == _refs.end()) {
833-
json target;
834-
if (ref.find("https://") == 0) {
835-
std::string base_url = ref.substr(0, ref.find('#'));
836-
auto it = _refs.find(base_url);
837-
if (it != _refs.end()) {
838-
target = it->second;
839-
} else {
840-
// Fetch the referenced schema and resolve its refs
841-
auto referenced = _fetch_json(ref);
842-
resolve_refs(referenced, base_url);
843-
_refs[base_url] = referenced;
844-
}
845-
if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) {
846-
return;
847-
}
848-
} else if (ref.find("#/") == 0) {
849-
target = schema;
850-
n["$ref"] = url + ref;
851-
ref = url + ref;
852-
} else {
853-
_errors.push_back("Unsupported ref: " + ref);
854-
return;
855-
}
856-
std::string pointer = ref.substr(ref.find('#') + 1);
857-
std::vector<std::string> tokens = split(pointer, "/");
858-
for (size_t i = 1; i < tokens.size(); ++i) {
859-
std::string sel = tokens[i];
860-
if (target.is_null() || !target.contains(sel)) {
861-
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
862-
return;
863-
}
864-
target = target[sel];
865-
}
866-
_refs[ref] = target;
867-
}
868-
} else {
869-
for (auto & kv : n.items()) {
870-
visit_refs(kv.value());
871-
}
872-
}
873-
}
874-
};
875-
876-
visit_refs(schema);
877-
}
878-
879808
std::string _generate_constant_rule(const json & value) {
880809
return format_literal(value.dump());
881810
}
882811

812+
struct ResolvedRef {
813+
json target;
814+
std::string name;
815+
bool is_local;
816+
};
817+
818+
ResolvedRef _resolve_ref(const std::string & ref) {
819+
auto parts = split(ref, "#");
820+
if (parts.size() != 2) {
821+
_errors.push_back("Unsupported ref: " + ref);
822+
return {json(), "", false};
823+
}
824+
const auto & url = parts[0];
825+
json target;
826+
bool is_local = url.empty();
827+
if (is_local) {
828+
if (_ref_context.empty()) {
829+
_errors.push_back("Error resolving ref " + ref + ": no context");
830+
return {json(), "", false};
831+
}
832+
target = _ref_context.back();
833+
} else {
834+
auto it = _external_refs.find(url);
835+
if (it != _external_refs.end()) {
836+
target = it->second;
837+
} else {
838+
// Fetch the referenced schema and resolve its refs
839+
target = _fetch_json(url);
840+
_external_refs[url] = target;
841+
}
842+
}
843+
auto tokens = split(parts[1], "/");
844+
for (size_t i = 1; i < tokens.size(); ++i) {
845+
const auto & sel = tokens[i];
846+
if (target.is_null() || !target.contains(sel)) {
847+
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
848+
return {json(), "", false};
849+
}
850+
target = target[sel];
851+
}
852+
return {target, tokens.empty() ? "" : tokens[tokens.size() - 1], is_local};
853+
}
854+
883855
std::string visit(const json & schema, const std::string & name) {
884856
json schema_type = schema.contains("type") ? schema["type"] : json();
885857
std::string schema_format = schema.contains("format") ? schema["format"].get<std::string>() : "";
886858
std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name;
887859

888-
if (schema.contains("$ref")) {
889-
return _add_rule(rule_name, _resolve_ref(schema["$ref"]));
860+
if (_ref_context.empty()) {
861+
_ref_context.push_back(schema);
862+
auto ret = visit(schema, name);
863+
_ref_context.pop_back();
864+
return ret;
865+
}
866+
867+
if (schema.contains("$ref") && schema["$ref"].is_string()) {
868+
const auto & ref = schema["$ref"].get<std::string>();
869+
auto resolved = _resolve_ref(ref);
870+
if (resolved.target.is_null()) {
871+
return "";
872+
}
873+
if (!resolved.is_local) {
874+
_ref_context.push_back(resolved.target);
875+
}
876+
auto ret = visit(resolved.target, (name.empty() || resolved.name.empty()) ? name : resolved.name);
877+
if (!resolved.is_local) {
878+
_ref_context.pop_back();
879+
}
880+
return ret;
890881
} else if (schema.contains("oneOf") || schema.contains("anyOf")) {
891882
std::vector<json> alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get<std::vector<json>>() : schema["anyOf"].get<std::vector<json>>();
892883
return _add_rule(rule_name, _generate_union_rule(name, alt_schemas));
@@ -906,55 +897,6 @@ class SchemaConverter {
906897
enum_values.push_back(_generate_constant_rule(v));
907898
}
908899
return _add_rule(rule_name, "(" + join(enum_values.begin(), enum_values.end(), " | ") + ") space");
909-
} else if ((schema_type.is_null() || schema_type == "object")
910-
&& (schema.contains("properties") ||
911-
(schema.contains("additionalProperties") && schema["additionalProperties"] != true))) {
912-
std::unordered_set<std::string> required;
913-
if (schema.contains("required") && schema["required"].is_array()) {
914-
for (const auto & item : schema["required"]) {
915-
if (item.is_string()) {
916-
required.insert(item.get<std::string>());
917-
}
918-
}
919-
}
920-
std::vector<std::pair<std::string, json>> properties;
921-
if (schema.contains("properties")) {
922-
for (const auto & prop : schema["properties"].items()) {
923-
properties.emplace_back(prop.key(), prop.value());
924-
}
925-
}
926-
return _add_rule(rule_name,
927-
_build_object_rule(
928-
properties, required, name,
929-
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
930-
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
931-
std::unordered_set<std::string> required;
932-
std::vector<std::pair<std::string, json>> properties;
933-
std::string hybrid_name = name;
934-
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
935-
if (comp_schema.contains("$ref")) {
936-
add_component(_refs[comp_schema["$ref"]], is_required);
937-
} else if (comp_schema.contains("properties")) {
938-
for (const auto & prop : comp_schema["properties"].items()) {
939-
properties.emplace_back(prop.key(), prop.value());
940-
if (is_required) {
941-
required.insert(prop.key());
942-
}
943-
}
944-
} else {
945-
// todo warning
946-
}
947-
};
948-
for (auto & t : schema["allOf"]) {
949-
if (t.contains("anyOf")) {
950-
for (auto & tt : t["anyOf"]) {
951-
add_component(tt, false);
952-
}
953-
} else {
954-
add_component(t, true);
955-
}
956-
}
957-
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
958900
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
959901
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];
960902
if (items.is_array()) {
@@ -1005,8 +947,71 @@ class SchemaConverter {
1005947
_build_min_max_int(min_value, max_value, out);
1006948
out << ") space";
1007949
return _add_rule(rule_name, out.str());
1008-
} else if (schema.empty() || schema_type == "object") {
1009-
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
950+
} else if ((schema_type.is_null() || schema_type == "object")) {
951+
std::unordered_set<std::string> required;
952+
std::vector<std::pair<std::string, json>> properties;
953+
auto is_explicit_object = schema_type == "object";
954+
json additional_properties;
955+
if (schema.contains("additionalProperties")) {
956+
is_explicit_object = true;
957+
additional_properties = schema["additionalProperties"];
958+
}
959+
if (schema.contains("properties") && schema["properties"].is_object()) {
960+
is_explicit_object = true;
961+
for (const auto & prop : schema["properties"].items()) {
962+
if (prop.value().is_object()) {
963+
properties.emplace_back(prop.key(), prop.value());
964+
}
965+
}
966+
}
967+
if (schema.contains("required") && schema["required"].is_array()) {
968+
for (const auto & item : schema["required"]) {
969+
if (item.is_string()) {
970+
required.insert(item.get<std::string>());
971+
}
972+
}
973+
}
974+
if (schema.contains("allOf") && schema["allOf"].is_array()) {
975+
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
976+
if (comp_schema.contains("$ref") && comp_schema["$ref"].is_string()) {
977+
auto resolved = _resolve_ref(comp_schema["$ref"].get<std::string>());
978+
add_component(resolved.target, is_required);
979+
} else if (comp_schema.contains("properties")) {
980+
for (const auto & prop : comp_schema["properties"].items()) {
981+
properties.emplace_back(prop.key(), prop.value());
982+
if (is_required) {
983+
required.insert(prop.key());
984+
}
985+
}
986+
if (comp_schema.contains("additionalProperties")) {
987+
if (additional_properties.is_null()) {
988+
additional_properties = comp_schema["additionalProperties"];
989+
} else if (additional_properties != comp_schema["additionalProperties"]) {
990+
_warnings.push_back("Inconsistent additionalProperties in allOf");
991+
}
992+
}
993+
} else {
994+
// todo warning
995+
}
996+
};
997+
for (auto & t : schema["allOf"]) {
998+
if (t.contains("anyOf")) {
999+
for (auto & tt : t["anyOf"]) {
1000+
add_component(tt, false);
1001+
}
1002+
} else {
1003+
add_component(t, true);
1004+
}
1005+
}
1006+
}
1007+
if (properties.empty() && (additional_properties == true || additional_properties.is_null())) {
1008+
return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object")));
1009+
}
1010+
auto default_additional_properties = is_explicit_object ? json() : json(false);
1011+
return _add_rule(rule_name,
1012+
_build_object_rule(
1013+
properties, required, name,
1014+
additional_properties.is_null() ? default_additional_properties : additional_properties));
10101015
} else {
10111016
if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get<std::string>()) == PRIMITIVE_RULES.end()) {
10121017
_errors.push_back("Unrecognized schema: " + schema.dump());
@@ -1038,7 +1043,6 @@ class SchemaConverter {
10381043
std::string json_schema_to_grammar(const json & schema) {
10391044
SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false);
10401045
auto copy = schema;
1041-
converter.resolve_refs(copy, "input");
10421046
converter.visit(copy, "");
10431047
converter.check_errors();
10441048
return converter.format_grammar();

0 commit comments

Comments
 (0)