diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 2f233e2e7477f..35a05ddeb6a1a 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -8,6 +8,7 @@ #include #include #include +#include using json = nlohmann::ordered_json; @@ -392,10 +393,27 @@ class SchemaConverter { std::function _fetch_json; bool _dotall; std::map _rules; - std::unordered_map _refs; std::unordered_set _refs_being_resolved; std::vector _errors; std::vector _warnings; + std::unordered_map _external_refs; + std::vector _ref_context; + + struct with_context { + SchemaConverter * _this; + const json * _target; + with_context(SchemaConverter * _this, const json * target) : _this(_this), _target(target) { + if (target) { + _this->_ref_context.push_back(*target); + } + } + ~with_context() { + if (_target) { + GGML_ASSERT(_this->_ref_context.back() == *_target); // should not have been modified + _this->_ref_context.pop_back(); + } + } + }; std::string _add_rule(const std::string & name, const std::string & rule) { std::string esc_name = regex_replace(name, INVALID_RULE_CHARS_RE, "-"); @@ -683,17 +701,6 @@ class SchemaConverter { return out.str(); } - std::string _resolve_ref(const std::string & ref) { - std::string ref_name = ref.substr(ref.find_last_of('/') + 1); - if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) { - _refs_being_resolved.insert(ref); - json resolved = _refs[ref]; - ref_name = visit(resolved, ref_name); - _refs_being_resolved.erase(ref); - } - return ref_name; - } - std::string _build_object_rule( const std::vector> & properties, const std::unordered_set & required, @@ -815,78 +822,70 @@ class SchemaConverter { _rules["space"] = SPACE_RULE; } - void resolve_refs(json & schema, const std::string & url) { - /* - * Resolves all $ref fields in the given schema, fetching any remote schemas, - * replacing each $ref with absolute reference URL and populates _refs with the - * respective referenced (sub)schema dictionaries. - */ - std::function visit_refs = [&](json & n) { - if (n.is_array()) { - for (auto & x : n) { - visit_refs(x); - } - } else if (n.is_object()) { - if (n.contains("$ref")) { - std::string ref = n["$ref"]; - if (_refs.find(ref) == _refs.end()) { - json target; - if (ref.find("https://") == 0) { - std::string base_url = ref.substr(0, ref.find('#')); - auto it = _refs.find(base_url); - if (it != _refs.end()) { - target = it->second; - } else { - // Fetch the referenced schema and resolve its refs - auto referenced = _fetch_json(ref); - resolve_refs(referenced, base_url); - _refs[base_url] = referenced; - } - if (ref.find('#') == std::string::npos || ref.substr(ref.find('#') + 1).empty()) { - return; - } - } else if (ref.find("#/") == 0) { - target = schema; - n["$ref"] = url + ref; - ref = url + ref; - } else { - _errors.push_back("Unsupported ref: " + ref); - return; - } - std::string pointer = ref.substr(ref.find('#') + 1); - std::vector tokens = split(pointer, "/"); - for (size_t i = 1; i < tokens.size(); ++i) { - std::string sel = tokens[i]; - if (target.is_null() || !target.contains(sel)) { - _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); - return; - } - target = target[sel]; - } - _refs[ref] = target; - } - } else { - for (auto & kv : n.items()) { - visit_refs(kv.value()); - } - } - } - }; - - visit_refs(schema); - } + // const std::unordered_map & get_refs() const { + // return _refs; + // } std::string _generate_constant_rule(const json & value) { return format_literal(value.dump()); } + std::pair _resolve_ref(const std::string & ref) { + auto parts = split(ref, "#"); + if (parts.size() != 2) { + _errors.push_back("Unsupported ref: " + ref); + return {json(), false}; + } + const auto & url = parts[0]; + json target; + bool is_local = url.empty(); + if (is_local) { + if (_ref_context.empty()) { + _errors.push_back("Error resolving ref " + ref + ": no context"); + return {json(), false}; + } + target = _ref_context.back(); + } else { + auto it = _external_refs.find(url); + if (it != _external_refs.end()) { + target = it->second; + } else { + // Fetch the referenced schema and resolve its refs + auto referenced = _fetch_json(url); + // resolve_refs(referenced, url); + _external_refs[url] = referenced; + } + } + std::vector tokens = split(parts[1], "/"); + for (size_t i = 1; i < tokens.size(); ++i) { + std::string sel = tokens[i]; + if (target.is_null() || !target.contains(sel)) { + _errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump()); + return json(); + } + target = target[sel]; + } + return {target, is_local}; + } + std::string visit(const json & schema, const std::string & name) { json schema_type = schema.contains("type") ? schema["type"] : json(); std::string schema_format = schema.contains("format") ? schema["format"].get() : ""; std::string rule_name = is_reserved_name(name) ? name + "-" : name.empty() ? "root" : name; - if (schema.contains("$ref")) { - return _add_rule(rule_name, _resolve_ref(schema["$ref"])); + with_context wc(this, _ref_context.empty() ? &schema : nullptr); + + if (schema.contains("$ref") && schema["$ref"].is_string()) { + const auto & ref = schema["$ref"].get(); + auto pair = _resolve_ref(ref); + auto target = pair.first; + auto is_local = pair.second; + if (target.is_null()) { + return ""; + } + std::cout << target.dump(4) << std::endl; + with_context wc(this, is_local ? nullptr : &target); + return visit(target, name); } else if (schema.contains("oneOf") || schema.contains("anyOf")) { std::vector alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get>() : schema["anyOf"].get>(); return _add_rule(rule_name, _generate_union_rule(name, alt_schemas)); @@ -932,8 +931,9 @@ class SchemaConverter { std::vector> properties; std::string hybrid_name = name; std::function add_component = [&](const json & comp_schema, bool is_required) { - if (comp_schema.contains("$ref")) { - add_component(_refs[comp_schema["$ref"]], is_required); + if (comp_schema.contains("$ref") && comp_schema["$ref"].is_string()) { + auto target = _resolve_ref(schema["$ref"].get()); + add_component(target, is_required); } else if (comp_schema.contains("properties")) { for (const auto & prop : comp_schema["properties"].items()) { properties.emplace_back(prop.key(), prop.value()); @@ -1038,7 +1038,11 @@ class SchemaConverter { std::string json_schema_to_grammar(const json & schema) { SchemaConverter converter([](const std::string &) { return json::object(); }, /* dotall= */ false); auto copy = schema; - converter.resolve_refs(copy, "input"); + // converter.resolve_refs(copy, "input"); + std::cout << copy.dump(4) << std::endl; + // for (const auto & [n, j] : converter.get_refs()) { + // std::cout << "REF: " << n << " -> " << j.dump(4) << "\n"; + // } converter.visit(copy, ""); converter.check_errors(); return converter.format_grammar(); diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 92f6e3d47bae7..96415cb1ed2e1 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -243,8 +243,9 @@ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern): self._rules = { 'space': SPACE_RULE, } - self._refs = {} - self._refs_being_resolved = set() + self._external_refs = {} + # self._refs_being_resolved = set() + self._ref_context = [] def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( @@ -332,51 +333,6 @@ def _add_rule(self, name, rule): self._rules[key] = rule return key - def resolve_refs(self, schema: dict, url: str): - ''' - Resolves all $ref fields in the given schema, fetching any remote schemas, - replacing $ref with absolute reference URL and populating self._refs with the - respective referenced (sub)schema dictionaries. - ''' - def visit(n: dict): - if isinstance(n, list): - return [visit(x) for x in n] - elif isinstance(n, dict): - ref = n.get('$ref') - if ref is not None and ref not in self._refs: - if ref.startswith('https://'): - assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' - import requests - - frag_split = ref.split('#') - base_url = frag_split[0] - - target = self._refs.get(base_url) - if target is None: - target = self.resolve_refs(requests.get(ref).json(), base_url) - self._refs[base_url] = target - - if len(frag_split) == 1 or frag_split[-1] == '': - return target - elif ref.startswith('#/'): - target = schema - ref = f'{url}{ref}' - n['$ref'] = ref - else: - raise ValueError(f'Unsupported ref {ref}') - - for sel in ref.split('#')[-1].split('/')[1:]: - assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' - target = target[sel] - - self._refs[ref] = target - else: - for v in n.values(): - visit(v) - - return n - return visit(schema) - def _generate_union_rule(self, name, alt_schemas): return ' | '.join(( self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}') @@ -541,18 +497,33 @@ def join_seq(): else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space") - def _resolve_ref(self, ref): - ref_name = ref.split('/')[-1] - if ref_name not in self._rules and ref not in self._refs_being_resolved: - self._refs_being_resolved.add(ref) - resolved = self._refs[ref] - ref_name = self.visit(resolved, ref_name) - self._refs_being_resolved.remove(ref) - return ref_name - def _generate_constant_rule(self, value): return self._format_literal(json.dumps(value)) + def _resolve_ref(self, ref): + parts = ref.split('#') + assert len(parts) == 2, f'Unsupported ref: {ref}' + url = parts[0] + is_local = url == '' + if is_local: + assert self._refs_being_resolved, 'Error resolving ref {ref}: no context' + target = self._refs_being_resolved[-1] + else: + if url in self._refs: + target = self._refs[url] + else: + assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)' + import requests + referenced = requests.get(url).json() + self._refs[url] = referenced + target = referenced + + for sel in parts[1].split('/')[1:]: + assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}' + target = target[sel] + + return target + def visit(self, schema, name): schema_type = schema.get('type') schema_format = schema.get('format') diff --git a/tests/test-grammar-integration.cpp b/tests/test-grammar-integration.cpp index 0e21dc7959943..32b7902e366c1 100644 --- a/tests/test-grammar-integration.cpp +++ b/tests/test-grammar-integration.cpp @@ -1266,6 +1266,47 @@ static void test_json_schema() { // R"""({"productId": 1, "productName": "A green door", "price": 12.50, "tags": ["home", "green", "home"]})""", } ); + + test_schema( + "refs", + // Schema + R"""({ + "type": "array", + "minItems": 15, + "maxItems": 15, + "items": { "$ref": "#/$defs/TALK" }, + + "$defs": { + "characters": { "enum": ["Biff", "Alice"] }, + "emotes": { "enum": ["EXCLAMATION", "CONFUSION", "CHEERFUL", "LOVE", "ANGRY"] }, + + "TALK": { + "type": "object", + "required": [ "character", "emote", "dialog" ], + "properties": { + "character": { "$ref": "#/$defs/characters" }, + "emote": { "$ref": "#/$defs/emotes" }, + "dialog": { + "type": "string", + "minLength": 1, + "maxLength": 200 + } + } + } + } + })""", + // Passing strings + { + R"""({ + "character": "Alice", + "emote": "EXCLAMATION", + "dialog": "Hello, world!" + })""", + }, + // Failing strings + { + } + ); } int main() { diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index 3aaa11833de57..818d41988007c 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -344,6 +344,47 @@ static void test_all(const std::string & lang, std::function