Skip to content

Commit

Permalink
json: improved repetitions & builtin rule deps
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Apr 4, 2024
1 parent 8451cdb commit 375f85d
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 71 deletions.
107 changes: 64 additions & 43 deletions examples/json_schema_to_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,57 @@
import sys
from typing import Any, Dict, List, Set, Tuple, Union

def _build_repetition(content, up_to_n):
# return ' '.join([content] * n)
if up_to_n == 0:
return ''
return f'({content}{" " + _build_repetition(content, up_to_n-1) if up_to_n > 1 else ""})?'

class BuiltinRule:
def __init__(self, content: str, deps: list[str] = None):
self.content = content
self.deps = deps or []

def __str__(self):
assert false

_up_to_15_digits = _build_repetition('[0-9]', 15)

# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'

PRIMITIVE_RULES = {
'boolean': '("true" | "false") space',
'decimal-part': '[0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] [0-9]?)?)?)?)?)?)?)?)?)?',
'integral-part': '[0-9] | [1-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] ([0-9] [0-9]?)?)?)?)?)?)?)?)?)?',

# 'number': '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
# 'integer': '("-"? ([0-9] | [1-9] [0-9]*)) space',
'number': '("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space',
'integer': '("-"? integral-part) space',
'value' : 'object | array | string | number | boolean',
'object' : '"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space',
'array' : '"[" space ( value ("," space value)* )? "]" space',
'uuid' : '"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space',
'string': r''' "\"" (
'boolean': BuiltinRule('("true" | "false") space', []),
'decimal-part': BuiltinRule('[0-9] ' + _up_to_15_digits, []),
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
'number': BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
'integer': BuiltinRule('("-"? integral-part) space', ['integral-part']),
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
'uuid' : BuiltinRule('"\\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + ' "\\"" space', []),
'string': BuiltinRule(r''' "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space''',
'null': '"null" space',
)* "\"" space''', []),
'null': BuiltinRule('"null" space', []),
}
OBJECT_RULE_NAMES = ['object', 'array', 'string', 'integral-part', 'decimal-part', 'number', 'boolean', 'null', 'value']

# TODO: support "uri", "email" string formats
DATE_RULES = {
'date' : '[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )',
'time' : '([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )',
'date-time': 'date "T" time',
'date-string': '"\\"" date "\\"" space',
'time-string': '"\\"" time "\\"" space',
'date-time-string': '"\\"" date-time "\\"" space',
STRING_FORMAT_RULES = {
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
'date-time': BuiltinRule('date "T" time', ['date', 'time']),
'date-string': BuiltinRule('"\\"" date "\\"" space', ['date']),
'time-string': BuiltinRule('"\\"" time "\\"" space', ['time']),
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
}

DOTALL = '[\\U00000000-\\U0010FFFF]'
DOT = '[\\U00000000-\\x09\\x0B\\x0C\\x0E-\\U0010FFFF]'

RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *DATE_RULES.keys()])
RESERVED_NAMES = set(["root", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])

INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
Expand All @@ -54,8 +66,6 @@
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')

DATE_PATTERN = '[0-9]{4}-(0[1-9]|1[0-2])-([0-2][0-9]|3[0-1])'
TIME_PATTERN = '([01][0-9]|2[0-3])(:[0-5][0-9]){2}(\\.[0-9]{1,3})?(Z|[+-](([01][0-9]|2[0-3]):[0-5][0-9]))' # Cap millisecond precision w/ 3 digits

class SchemaConverter:
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
Expand All @@ -65,8 +75,6 @@ def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
self._raw_pattern = raw_pattern
self._rules = {
'space': SPACE_RULE,
'integral-part': PRIMITIVE_RULES['integral-part'],
'decimal-part': PRIMITIVE_RULES['decimal-part'],
}
self._refs = {}
self._refs_being_resolved = set()
Expand Down Expand Up @@ -420,7 +428,9 @@ def add_component(comp_schema, is_required):
successive_items = list_item_operator * (min_items - 1)
min_items -= 1
if max_items is not None and max_items > min_items:
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
# TODO: avoid grammar branch explosion here
successive_items += _build_repetition(list_item_operator, max_items - min_items - 1)
# successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
else:
successive_items += list_item_operator + "*"
if min_items == 0:
Expand All @@ -433,28 +443,39 @@ def add_component(comp_schema, is_required):
return self._visit_pattern(schema['pattern'], rule_name)

elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
return self._add_rule(
return self._add_primitive(
'root' if rule_name == 'root' else schema_format,
PRIMITIVE_RULES['uuid']
)

elif schema_type in (None, 'string') and schema_format in DATE_RULES:
for t, r in DATE_RULES.items():
self._add_rule(t, r)
return schema_format + '-string'
elif schema_type in (None, 'string') and schema_format in STRING_FORMAT_RULES:
return self._add_rule(rule_name, self._add_primitive(schema_format, STRING_FORMAT_RULES[schema_format]))

elif (schema_type == 'object') or (len(schema) == 0):
for n in OBJECT_RULE_NAMES:
self._add_rule(n, PRIMITIVE_RULES[n])
return self._add_rule(rule_name, 'object')
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))

else:
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
return self._add_rule(
'root' if rule_name == 'root' else schema_type,
PRIMITIVE_RULES[schema_type]
)
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])

def _add_primitive(self, name: str, rule: BuiltinRule):
assert isinstance(rule, BuiltinRule), f'rule: {rule}'
assert isinstance(rule.content, str), f'{name}: {rule.content}'
n = self._add_rule(name, rule.content)

for dep in rule.deps:
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
assert dep_rule, f'Rule {dep} not known'
if dep not in self._rules:
self._add_primitive(dep, dep_rule)
return n

def _build_number_rule(self):
_up_to_15_digits = _build_repetition('[0-9]', 15)
decimal_rule = self._add_rule('decimal-part', f'[0-9] {_up_to_15_digits}')
integral_rule = self._add_rule('integral-part', f'[0-9] | [1-9] {_up_to_15_digits}')
return self._add_rule('number', f'("-"? {integral_rule}) ("." {decimal_rule})? ([eE] [-+]? {integral_rule})? space')

def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
prop_order = self._prop_order
Expand All @@ -476,7 +497,7 @@ def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[st
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
prop_kv_rule_names["*"] = self._add_rule(
f'{sub_name}-kv',
self._add_rule('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
)
optional_props.append("*")

Expand Down
Loading

0 comments on commit 375f85d

Please sign in to comment.