feat: update grammar schema converter to match llama.cpp (#1353)
* feat: improve function calling * feat:grammar * fix * fix * fix
This commit is contained in:
parent
fa4bb0cf81
commit
4f42664955
2 changed files with 523 additions and 106 deletions
|
@ -2709,4 +2709,4 @@ def chatml_function_calling(
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
raise ValueError("Automatic streaming tool choice is not supported")
|
raise ValueError("Automatic streaming tool choice is not supported")
|
|
@ -5,11 +5,12 @@ from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
from ctypes import * # type: ignore
|
from ctypes import * # type: ignore
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from itertools import islice
|
from itertools import islice, groupby
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
|
Set,
|
||||||
Generic,
|
Generic,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
@ -1391,145 +1392,561 @@ from typing import List, Optional
|
||||||
# whitespace. Also maybe improves generation quality?
|
# whitespace. Also maybe improves generation quality?
|
||||||
SPACE_RULE = '" "?'
|
SPACE_RULE = '" "?'
|
||||||
|
|
||||||
PRIMITIVE_RULES = {
|
|
||||||
"boolean": '("true" | "false") space',
|
|
||||||
"number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
|
|
||||||
"integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
|
|
||||||
"string": r""" "\"" (
|
|
||||||
[^"\\] |
|
|
||||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
|
||||||
)* "\"" space """,
|
|
||||||
"null": '"null" space',
|
|
||||||
}
|
|
||||||
|
|
||||||
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
|
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
|
||||||
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
||||||
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
|
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
|
||||||
|
|
||||||
|
# whitespace is constrained to a single space char to prevent model "running away" in
|
||||||
|
# whitespace. Also maybe improves generation quality?
|
||||||
|
SPACE_RULE = '" "?'
|
||||||
|
|
||||||
|
|
||||||
|
def _build_repetition(item_rule, min_items, max_items, separator_rule=None, item_rule_is_literal=False):
|
||||||
|
if not separator_rule:
|
||||||
|
if min_items == 0 and max_items == 1:
|
||||||
|
return f'{item_rule}?'
|
||||||
|
elif min_items == 1 and max_items is None:
|
||||||
|
return f'{item_rule}+'
|
||||||
|
|
||||||
|
result = ''
|
||||||
|
|
||||||
|
if min_items > 0:
|
||||||
|
if item_rule_is_literal and separator_rule is None:
|
||||||
|
result = '"' + (item_rule[1:-1] * min_items) + '"'
|
||||||
|
else:
|
||||||
|
result = (f' {separator_rule} ' if separator_rule else ' ').join([item_rule] * min_items)
|
||||||
|
|
||||||
|
def opt_repetitions(up_to_n, prefix_with_sep=False):
|
||||||
|
'''
|
||||||
|
- n=4, no sep: '(a (a (a (a)?)?)?)?'
|
||||||
|
- n=4, sep=',', prefix: '("," a ("," a ("," a ("," a)?)?)?)?'
|
||||||
|
- n=4, sep=',', no prefix: '(a ("," a ("," a ("," a)?)?)?)?'
|
||||||
|
'''
|
||||||
|
|
||||||
|
content = f'{separator_rule} {item_rule}' if prefix_with_sep and separator_rule else item_rule
|
||||||
|
if up_to_n == 0:
|
||||||
|
return ''
|
||||||
|
elif up_to_n == 1:
|
||||||
|
return f'({content})?'
|
||||||
|
elif separator_rule and not prefix_with_sep:
|
||||||
|
return f'({content} {opt_repetitions(up_to_n - 1, prefix_with_sep=True)})?'
|
||||||
|
else:
|
||||||
|
return (f'({content} ' * up_to_n).rstrip() + (')?' * up_to_n)
|
||||||
|
|
||||||
|
if min_items > 0 and max_items != min_items:
|
||||||
|
result += ' '
|
||||||
|
|
||||||
|
if max_items is not None:
|
||||||
|
result += opt_repetitions(max_items - min_items, prefix_with_sep=min_items > 0)
|
||||||
|
else:
|
||||||
|
item_operator = f'({separator_rule + " " if separator_rule else ""}{item_rule})'
|
||||||
|
|
||||||
|
if min_items == 0 and separator_rule:
|
||||||
|
result = f'({item_rule} {item_operator}*)?'
|
||||||
|
else:
|
||||||
|
result += f'{item_operator}*'
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class BuiltinRule:
|
||||||
|
def __init__(self, content: str, deps: list = None):
|
||||||
|
self.content = content
|
||||||
|
self.deps = deps or []
|
||||||
|
|
||||||
|
_up_to_15_digits = _build_repetition('[0-9]', 0, 15)
|
||||||
|
|
||||||
|
PRIMITIVE_RULES = {
|
||||||
|
'boolean' : BuiltinRule('("true" | "false") space', []),
|
||||||
|
'decimal-part' : BuiltinRule('[0-9] ' + _up_to_15_digits, []),
|
||||||
|
'integral-part': BuiltinRule('[0-9] | [1-9] ' + _up_to_15_digits, []),
|
||||||
|
'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']),
|
||||||
|
'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']),
|
||||||
|
'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']),
|
||||||
|
'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']),
|
||||||
|
'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']),
|
||||||
|
'uuid' : BuiltinRule(r'"\"" ' + ' "-" '.join('[0-9a-fA-F]' * n for n in [8, 4, 4, 4, 12]) + r' "\"" space', []),
|
||||||
|
'char' : BuiltinRule(r'[^"\\] | "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])', []),
|
||||||
|
'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']),
|
||||||
|
'null' : BuiltinRule('"null" space', []),
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: support "uri", "email" string formats
|
||||||
|
STRING_FORMAT_RULES = {
|
||||||
|
'date' : BuiltinRule('[0-9] [0-9] [0-9] [0-9] "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []),
|
||||||
|
'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9] [0-9] [0-9] )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []),
|
||||||
|
'date-time' : BuiltinRule('date "T" time', ['date', 'time']),
|
||||||
|
'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']),
|
||||||
|
'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']),
|
||||||
|
'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']),
|
||||||
|
}
|
||||||
|
|
||||||
|
DOTALL = '[\\U00000000-\\U0010FFFF]'
|
||||||
|
DOT = '[^\\x0A\\x0D]'
|
||||||
|
|
||||||
|
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
|
||||||
|
|
||||||
|
|
||||||
|
NON_LITERAL_SET = set('|.()[]{}*+?')
|
||||||
|
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('[]()|{}*+?')
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SchemaConverter:
|
class SchemaConverter:
|
||||||
def __init__(self, prop_order):
|
def __init__(self, *, prop_order, allow_fetch, dotall, raw_pattern):
|
||||||
self._prop_order = prop_order
|
self._prop_order = prop_order
|
||||||
self._rules = {"space": SPACE_RULE}
|
self._allow_fetch = allow_fetch
|
||||||
self._defs: Dict[str, Any] = {}
|
self._dotall = dotall
|
||||||
|
self._raw_pattern = raw_pattern
|
||||||
|
self._rules = {
|
||||||
|
'space': SPACE_RULE,
|
||||||
|
}
|
||||||
|
self._refs = {}
|
||||||
|
self._refs_being_resolved = set()
|
||||||
|
|
||||||
def _format_literal(self, literal: str):
|
def _format_literal(self, literal):
|
||||||
escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||||
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
|
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), literal
|
||||||
)
|
)
|
||||||
return f'"{escaped}"'
|
return f'"{escaped}"'
|
||||||
|
|
||||||
def _add_rule(self, name: str, rule: str):
|
def not_literal(self, literal: str, dotall: bool = True, maybe_escaped_underscores = False) -> str:
|
||||||
esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
|
'''
|
||||||
|
not_literal('a') -> '[^a]'
|
||||||
|
not_literal('abc') -> '([^a] | "a" ([^b] | "b" ([^c])?)?)?'
|
||||||
|
'''
|
||||||
|
assert len(literal) > 0, 'Empty literal not supported'
|
||||||
|
def recurse(i: int):
|
||||||
|
c = literal[i]
|
||||||
|
if maybe_escaped_underscores and c == '_':
|
||||||
|
yield f'[^{c}\\\\]'
|
||||||
|
yield ' | '
|
||||||
|
yield f'"\\\\"? "{c}"'
|
||||||
|
else:
|
||||||
|
yield f'[^{c}]'
|
||||||
|
if i < len(literal) - 1:
|
||||||
|
yield ' | '
|
||||||
|
yield self._format_literal(c)
|
||||||
|
yield ' ('
|
||||||
|
yield from recurse(i + 1)
|
||||||
|
yield ')?'
|
||||||
|
|
||||||
|
return ''.join(('(', *recurse(0), ')'))
|
||||||
|
|
||||||
|
def _add_rule(self, name, rule):
|
||||||
|
esc_name = INVALID_RULE_CHARS_RE.sub('-', name)
|
||||||
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
||||||
key = esc_name
|
key = esc_name
|
||||||
else:
|
else:
|
||||||
i = 0
|
i = 0
|
||||||
while f"{esc_name}{i}" in self._rules:
|
while f'{esc_name}{i}' in self._rules and self._rules[f'{esc_name}{i}'] != rule:
|
||||||
i += 1
|
i += 1
|
||||||
key = f"{esc_name}{i}"
|
key = f'{esc_name}{i}'
|
||||||
self._rules[key] = rule
|
self._rules[key] = rule
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def visit(self, schema: Dict[str, Any], name: str) -> str:
|
def resolve_refs(self, schema: dict, url: str):
|
||||||
rule_name = name or "root"
|
'''
|
||||||
|
Resolves all $ref fields in the given schema, fetching any remote schemas,
|
||||||
|
replacing $ref with absolute reference URL and populating self._refs with the
|
||||||
|
respective referenced (sub)schema dictionaries.
|
||||||
|
'''
|
||||||
|
def visit(n: dict):
|
||||||
|
if isinstance(n, list):
|
||||||
|
return [visit(x) for x in n]
|
||||||
|
elif isinstance(n, dict):
|
||||||
|
ref = n.get('$ref')
|
||||||
|
if ref is not None and ref not in self._refs:
|
||||||
|
if ref.startswith('https://'):
|
||||||
|
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
|
||||||
|
import requests
|
||||||
|
|
||||||
if "$defs" in schema:
|
frag_split = ref.split('#')
|
||||||
# add defs to self._defs for later inlining
|
base_url = frag_split[0]
|
||||||
for def_name, def_schema in schema["$defs"].items():
|
|
||||||
self._defs[def_name] = def_schema
|
|
||||||
|
|
||||||
if "oneOf" in schema or "anyOf" in schema:
|
target = self._refs.get(base_url)
|
||||||
rule = " | ".join(
|
if target is None:
|
||||||
(
|
target = self.resolve_refs(requests.get(ref).json(), base_url)
|
||||||
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
|
self._refs[base_url] = target
|
||||||
for i, alt_schema in enumerate(
|
|
||||||
schema.get("oneOf") or schema["anyOf"]
|
if len(frag_split) == 1 or frag_split[-1] == '':
|
||||||
)
|
return target
|
||||||
)
|
elif ref.startswith('#/'):
|
||||||
|
target = schema
|
||||||
|
ref = f'{url}{ref}'
|
||||||
|
n['$ref'] = ref
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Unsupported ref {ref}')
|
||||||
|
|
||||||
|
for sel in ref.split('#')[-1].split('/')[1:]:
|
||||||
|
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
|
||||||
|
target = target[sel]
|
||||||
|
|
||||||
|
self._refs[ref] = target
|
||||||
|
else:
|
||||||
|
for v in n.values():
|
||||||
|
visit(v)
|
||||||
|
|
||||||
|
return n
|
||||||
|
return visit(schema)
|
||||||
|
|
||||||
|
def _generate_union_rule(self, name, alt_schemas):
|
||||||
|
return ' | '.join((
|
||||||
|
self.visit(alt_schema, f'{name}{"-" if name else "alternative-"}{i}')
|
||||||
|
for i, alt_schema in enumerate(alt_schemas)
|
||||||
|
))
|
||||||
|
|
||||||
|
def _visit_pattern(self, pattern, name):
|
||||||
|
'''
|
||||||
|
Transforms a regular expression pattern into a GBNF rule.
|
||||||
|
|
||||||
|
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
|
||||||
|
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md
|
||||||
|
|
||||||
|
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.
|
||||||
|
|
||||||
|
Mostly a 1:1 translation, except for {x} / {x,} / {x,y} quantifiers for which
|
||||||
|
we define sub-rules to keep the output lean.
|
||||||
|
'''
|
||||||
|
|
||||||
|
assert pattern.startswith('^') and pattern.endswith('$'), 'Pattern must start with "^" and end with "$"'
|
||||||
|
pattern = pattern[1:-1]
|
||||||
|
sub_rule_ids = {}
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
length = len(pattern)
|
||||||
|
|
||||||
|
def to_rule(s: Tuple[str, bool]) -> str:
|
||||||
|
(txt, is_literal) = s
|
||||||
|
return "\"" + txt + "\"" if is_literal else txt
|
||||||
|
|
||||||
|
def transform() -> Tuple[str, bool]:
|
||||||
|
'''
|
||||||
|
Parse a unit at index i (advancing it), and return its string representation + whether it's a literal.
|
||||||
|
'''
|
||||||
|
nonlocal i
|
||||||
|
nonlocal pattern
|
||||||
|
nonlocal sub_rule_ids
|
||||||
|
|
||||||
|
start = i
|
||||||
|
# For each component of this sequence, store its string representation and whether it's a literal.
|
||||||
|
# We only need a flat structure here to apply repetition operators to the last item, and
|
||||||
|
# to merge literals at the and (we're parsing grouped ( sequences ) recursively and don't treat '|' specially
|
||||||
|
# (GBNF's syntax is luckily very close to regular expressions!)
|
||||||
|
seq: list[Tuple[str, bool]] = []
|
||||||
|
|
||||||
|
def get_dot():
|
||||||
|
if self._dotall:
|
||||||
|
rule = DOTALL
|
||||||
|
else:
|
||||||
|
# Accept any character... except \n and \r line break chars (\x0A and \xOD)
|
||||||
|
rule = DOT
|
||||||
|
return self._add_rule(f'dot', rule)
|
||||||
|
|
||||||
|
def join_seq():
|
||||||
|
nonlocal seq
|
||||||
|
ret = []
|
||||||
|
for is_literal, g in groupby(seq, lambda x: x[1]):
|
||||||
|
if is_literal:
|
||||||
|
ret.append((''.join(x[0] for x in g), True))
|
||||||
|
else:
|
||||||
|
ret.extend(g)
|
||||||
|
if len(ret) == 1:
|
||||||
|
return ret[0]
|
||||||
|
return (' '.join(to_rule(x) for x in seq), False)
|
||||||
|
|
||||||
|
while i < length:
|
||||||
|
c = pattern[i]
|
||||||
|
if c == '.':
|
||||||
|
seq.append((get_dot(), False))
|
||||||
|
i += 1
|
||||||
|
elif c == '(':
|
||||||
|
i += 1
|
||||||
|
if i < length:
|
||||||
|
assert pattern[i] != '?', f'Unsupported pattern syntax "{pattern[i]}" at index {i} of /{pattern}/'
|
||||||
|
seq.append((f'({to_rule(transform())})', False))
|
||||||
|
elif c == ')':
|
||||||
|
i += 1
|
||||||
|
assert start > 0 and pattern[start-1] == '(', f'Unbalanced parentheses; start = {start}, i = {i}, pattern = {pattern}'
|
||||||
|
return join_seq()
|
||||||
|
elif c == '[':
|
||||||
|
square_brackets = c
|
||||||
|
i += 1
|
||||||
|
while i < length and pattern[i] != ']':
|
||||||
|
if pattern[i] == '\\':
|
||||||
|
square_brackets += pattern[i:i+2]
|
||||||
|
i += 2
|
||||||
|
else:
|
||||||
|
square_brackets += pattern[i]
|
||||||
|
i += 1
|
||||||
|
assert i < length, f'Unbalanced square brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||||
|
square_brackets += ']'
|
||||||
|
i += 1
|
||||||
|
seq.append((square_brackets, False))
|
||||||
|
elif c == '|':
|
||||||
|
seq.append(('|', False))
|
||||||
|
i += 1
|
||||||
|
elif c in ('*', '+', '?'):
|
||||||
|
seq[-1] = (to_rule(seq[-1]) + c, False)
|
||||||
|
i += 1
|
||||||
|
elif c == '{':
|
||||||
|
curly_brackets = c
|
||||||
|
i += 1
|
||||||
|
while i < length and pattern[i] != '}':
|
||||||
|
curly_brackets += pattern[i]
|
||||||
|
i += 1
|
||||||
|
assert i < length, f'Unbalanced curly brackets; start = {start}, i = {i}, pattern = {pattern}'
|
||||||
|
curly_brackets += '}'
|
||||||
|
i += 1
|
||||||
|
nums = [s.strip() for s in curly_brackets[1:-1].split(',')]
|
||||||
|
min_times = 0
|
||||||
|
max_times = None
|
||||||
|
try:
|
||||||
|
if len(nums) == 1:
|
||||||
|
min_times = int(nums[0])
|
||||||
|
max_times = min_times
|
||||||
|
else:
|
||||||
|
assert len(nums) == 2
|
||||||
|
min_times = int(nums[0]) if nums[0] else 0
|
||||||
|
max_times = int(nums[1]) if nums[1] else None
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(f'Invalid quantifier {curly_brackets} in /{pattern}/')
|
||||||
|
|
||||||
|
(sub, sub_is_literal) = seq[-1]
|
||||||
|
|
||||||
|
if not sub_is_literal:
|
||||||
|
id = sub_rule_ids.get(sub)
|
||||||
|
if id is None:
|
||||||
|
id = self._add_rule(f'{name}-{len(sub_rule_ids) + 1}', sub)
|
||||||
|
sub_rule_ids[sub] = id
|
||||||
|
sub = id
|
||||||
|
|
||||||
|
seq[-1] = (_build_repetition(f'"{sub}"' if sub_is_literal else sub, min_times, max_times, item_rule_is_literal=sub_is_literal), False)
|
||||||
|
else:
|
||||||
|
literal = ''
|
||||||
|
while i < length:
|
||||||
|
if pattern[i] == '\\' and i < length - 1:
|
||||||
|
next = pattern[i + 1]
|
||||||
|
if next in ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS:
|
||||||
|
i += 1
|
||||||
|
literal += pattern[i]
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
literal += pattern[i:i+2]
|
||||||
|
i += 2
|
||||||
|
elif pattern[i] == '"' and not self._raw_pattern:
|
||||||
|
literal += '\\"'
|
||||||
|
i += 1
|
||||||
|
elif pattern[i] not in NON_LITERAL_SET and \
|
||||||
|
(i == length - 1 or literal == '' or pattern[i+1] == '.' or pattern[i+1] not in NON_LITERAL_SET):
|
||||||
|
literal += pattern[i]
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
if literal:
|
||||||
|
seq.append((literal, True))
|
||||||
|
|
||||||
|
return join_seq()
|
||||||
|
|
||||||
|
return self._add_rule(
|
||||||
|
name,
|
||||||
|
to_rule(transform()) if self._raw_pattern \
|
||||||
|
else "\"\\\"\" " + to_rule(transform()) + " \"\\\"\" space")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_ref(self, ref):
|
||||||
|
ref_name = ref.split('/')[-1]
|
||||||
|
if ref_name not in self._rules and ref not in self._refs_being_resolved:
|
||||||
|
self._refs_being_resolved.add(ref)
|
||||||
|
resolved = self._refs[ref]
|
||||||
|
ref_name = self.visit(resolved, ref_name)
|
||||||
|
self._refs_being_resolved.remove(ref)
|
||||||
|
return ref_name
|
||||||
|
|
||||||
|
def _generate_constant_rule(self, value):
|
||||||
|
return self._format_literal(json.dumps(value))
|
||||||
|
|
||||||
|
def visit(self, schema, name):
|
||||||
|
schema_type = schema.get('type')
|
||||||
|
schema_format = schema.get('format')
|
||||||
|
rule_name = name + '-' if name in RESERVED_NAMES else name or 'root'
|
||||||
|
|
||||||
|
if (ref := schema.get('$ref')) is not None:
|
||||||
|
return self._add_rule(rule_name, self._resolve_ref(ref))
|
||||||
|
|
||||||
|
elif 'oneOf' in schema or 'anyOf' in schema:
|
||||||
|
return self._add_rule(rule_name, self._generate_union_rule(name, schema.get('oneOf') or schema['anyOf']))
|
||||||
|
|
||||||
|
elif isinstance(schema_type, list):
|
||||||
|
return self._add_rule(rule_name, self._generate_union_rule(name, [{'type': t} for t in schema_type]))
|
||||||
|
|
||||||
|
elif 'const' in schema:
|
||||||
|
return self._add_rule(rule_name, self._generate_constant_rule(schema['const']))
|
||||||
|
|
||||||
|
elif 'enum' in schema:
|
||||||
|
rule = ' | '.join((self._generate_constant_rule(v) for v in schema['enum']))
|
||||||
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
|
elif schema_type in (None, 'object') and \
|
||||||
|
('properties' in schema or \
|
||||||
|
('additionalProperties' in schema and schema['additionalProperties'] is not True)):
|
||||||
|
required = set(schema.get('required', []))
|
||||||
|
properties = list(schema.get('properties', {}).items())
|
||||||
|
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
|
||||||
|
|
||||||
|
elif schema_type in (None, 'object') and 'allOf' in schema:
|
||||||
|
required = set()
|
||||||
|
properties = []
|
||||||
|
hybrid_name = name
|
||||||
|
def add_component(comp_schema, is_required):
|
||||||
|
if (ref := comp_schema.get('$ref')) is not None:
|
||||||
|
comp_schema = self._refs[ref]
|
||||||
|
|
||||||
|
if 'properties' in comp_schema:
|
||||||
|
for prop_name, prop_schema in comp_schema['properties'].items():
|
||||||
|
properties.append((prop_name, prop_schema))
|
||||||
|
if is_required:
|
||||||
|
required.add(prop_name)
|
||||||
|
|
||||||
|
for t in schema['allOf']:
|
||||||
|
if 'anyOf' in t:
|
||||||
|
for tt in t['anyOf']:
|
||||||
|
add_component(tt, is_required=False)
|
||||||
|
else:
|
||||||
|
add_component(t, is_required=True)
|
||||||
|
|
||||||
|
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=[]))
|
||||||
|
|
||||||
|
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
|
||||||
|
items = schema.get('items') or schema['prefixItems']
|
||||||
|
if isinstance(items, list):
|
||||||
|
return self._add_rule(
|
||||||
|
rule_name,
|
||||||
|
'"[" space ' +
|
||||||
|
' "," space '.join(
|
||||||
|
self.visit(item, f'{name}{"-" if name else ""}tuple-{i}')
|
||||||
|
for i, item in enumerate(items)) +
|
||||||
|
' "]" space')
|
||||||
|
else:
|
||||||
|
item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item')
|
||||||
|
min_items = schema.get("minItems", 0)
|
||||||
|
max_items = schema.get("maxItems")
|
||||||
|
return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space')
|
||||||
|
|
||||||
|
elif schema_type in (None, 'string') and 'pattern' in schema:
|
||||||
|
return self._visit_pattern(schema['pattern'], rule_name)
|
||||||
|
|
||||||
|
elif schema_type in (None, 'string') and re.match(r'^uuid[1-5]?$', schema_format or ''):
|
||||||
|
return self._add_primitive(
|
||||||
|
'root' if rule_name == 'root' else schema_format,
|
||||||
|
PRIMITIVE_RULES['uuid']
|
||||||
)
|
)
|
||||||
return self._add_rule(rule_name, rule)
|
|
||||||
|
|
||||||
elif "const" in schema:
|
elif schema_type in (None, 'string') and f'{schema_format}-string' in STRING_FORMAT_RULES:
|
||||||
return self._add_rule(rule_name, self._format_literal(schema["const"]))
|
prim_name = f'{schema_format}-string'
|
||||||
|
return self._add_rule(rule_name, self._add_primitive(prim_name, STRING_FORMAT_RULES[prim_name]))
|
||||||
|
|
||||||
elif "enum" in schema:
|
elif schema_type == 'string' and ('minLength' in schema or 'maxLength' in schema):
|
||||||
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
|
char_rule = self._add_primitive('char', PRIMITIVE_RULES['char'])
|
||||||
return self._add_rule(rule_name, rule)
|
min_len = schema.get('minLength', 0)
|
||||||
|
max_len = schema.get('maxLength')
|
||||||
|
|
||||||
elif "$ref" in schema:
|
return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space')
|
||||||
ref = schema["$ref"]
|
|
||||||
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
|
|
||||||
# inline $defs
|
|
||||||
def_name = ref[len("#/$defs/") :]
|
|
||||||
def_schema = self._defs[def_name]
|
|
||||||
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
|
|
||||||
|
|
||||||
|
elif (schema_type == 'object') or (len(schema) == 0):
|
||||||
schema_type: Optional[str] = schema.get("type") # type: ignore
|
return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object']))
|
||||||
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
|
|
||||||
|
|
||||||
if schema_type == "object" and "properties" in schema:
|
|
||||||
# TODO: `required` keyword
|
|
||||||
if self._prop_order:
|
|
||||||
prop_order = self._prop_order
|
|
||||||
prop_pairs = sorted(
|
|
||||||
schema["properties"].items(),
|
|
||||||
# sort by position in prop_order (if specified) then by key
|
|
||||||
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
prop_pairs = schema["properties"].items()
|
|
||||||
|
|
||||||
rule = '"{" space'
|
|
||||||
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
|
|
||||||
prop_rule_name = self.visit(
|
|
||||||
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
|
|
||||||
)
|
|
||||||
if i > 0:
|
|
||||||
rule += ' "," space'
|
|
||||||
rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
|
|
||||||
rule += ' "}" space'
|
|
||||||
|
|
||||||
return self._add_rule(rule_name, rule)
|
|
||||||
|
|
||||||
elif schema_type == "array" and "items" in schema:
|
|
||||||
# TODO `prefixItems` keyword
|
|
||||||
item_rule_name = self.visit(
|
|
||||||
schema["items"], f'{name}{"-" if name else ""}item'
|
|
||||||
)
|
|
||||||
list_item_operator = f'("," space {item_rule_name})'
|
|
||||||
successive_items = ""
|
|
||||||
min_items = schema.get("minItems", 0)
|
|
||||||
if min_items > 0:
|
|
||||||
first_item = f"({item_rule_name})"
|
|
||||||
successive_items = list_item_operator * (min_items - 1)
|
|
||||||
min_items -= 1
|
|
||||||
else:
|
|
||||||
first_item = f"({item_rule_name})?"
|
|
||||||
max_items = schema.get("maxItems")
|
|
||||||
if max_items is not None and max_items > min_items:
|
|
||||||
successive_items += (list_item_operator + "?") * (max_items - min_items - 1)
|
|
||||||
else:
|
|
||||||
successive_items += list_item_operator + "*"
|
|
||||||
rule = f'"[" space {first_item} {successive_items} "]" space'
|
|
||||||
return self._add_rule(rule_name, rule)
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}'
|
||||||
return self._add_rule(
|
# TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero
|
||||||
"root" if rule_name == "root" else schema_type,
|
return self._add_primitive('root' if rule_name == 'root' else schema_type, PRIMITIVE_RULES[schema_type])
|
||||||
PRIMITIVE_RULES[schema_type],
|
|
||||||
|
def _add_primitive(self, name: str, rule: BuiltinRule):
|
||||||
|
n = self._add_rule(name, rule.content)
|
||||||
|
|
||||||
|
for dep in rule.deps:
|
||||||
|
dep_rule = PRIMITIVE_RULES.get(dep) or STRING_FORMAT_RULES.get(dep)
|
||||||
|
assert dep_rule, f'Rule {dep} not known'
|
||||||
|
if dep not in self._rules:
|
||||||
|
self._add_primitive(dep, dep_rule)
|
||||||
|
return n
|
||||||
|
|
||||||
|
def _build_object_rule(self, properties: List[Tuple[str, Any]], required: Set[str], name: str, additional_properties: Union[bool, Any]):
|
||||||
|
prop_order = self._prop_order
|
||||||
|
# sort by position in prop_order (if specified) then by original order
|
||||||
|
sorted_props = [kv[0] for _, kv in sorted(enumerate(properties), key=lambda ikv: (prop_order.get(ikv[1][0], len(prop_order)), ikv[0]))]
|
||||||
|
|
||||||
|
prop_kv_rule_names = {}
|
||||||
|
for prop_name, prop_schema in properties:
|
||||||
|
prop_rule_name = self.visit(prop_schema, f'{name}{"-" if name else ""}{prop_name}')
|
||||||
|
prop_kv_rule_names[prop_name] = self._add_rule(
|
||||||
|
f'{name}{"-" if name else ""}{prop_name}-kv',
|
||||||
|
fr'{self._format_literal(json.dumps(prop_name))} space ":" space {prop_rule_name}'
|
||||||
)
|
)
|
||||||
|
required_props = [k for k in sorted_props if k in required]
|
||||||
|
optional_props = [k for k in sorted_props if k not in required]
|
||||||
|
|
||||||
|
if additional_properties == True or isinstance(additional_properties, dict):
|
||||||
|
sub_name = f'{name}{"-" if name else ""}additional'
|
||||||
|
value_rule = self.visit({} if additional_properties == True else additional_properties, f'{sub_name}-value')
|
||||||
|
prop_kv_rule_names["*"] = self._add_rule(
|
||||||
|
f'{sub_name}-kv',
|
||||||
|
self._add_primitive('string', PRIMITIVE_RULES['string']) + f' ":" space {value_rule}'
|
||||||
|
)
|
||||||
|
optional_props.append("*")
|
||||||
|
|
||||||
|
rule = '"{" space '
|
||||||
|
rule += ' "," space '.join(prop_kv_rule_names[k] for k in required_props)
|
||||||
|
|
||||||
|
if optional_props:
|
||||||
|
rule += ' ('
|
||||||
|
if required_props:
|
||||||
|
rule += ' "," space ( '
|
||||||
|
|
||||||
|
def get_recursive_refs(ks, first_is_optional):
|
||||||
|
[k, *rest] = ks
|
||||||
|
kv_rule_name = prop_kv_rule_names[k]
|
||||||
|
if k == '*':
|
||||||
|
res = self._add_rule(
|
||||||
|
f'{name}{"-" if name else ""}additional-kvs',
|
||||||
|
f'{kv_rule_name} ( "," space ' + kv_rule_name + ' )*'
|
||||||
|
)
|
||||||
|
elif first_is_optional:
|
||||||
|
res = f'( "," space {kv_rule_name} )?'
|
||||||
|
else:
|
||||||
|
res = kv_rule_name
|
||||||
|
if len(rest) > 0:
|
||||||
|
res += ' ' + self._add_rule(
|
||||||
|
f'{name}{"-" if name else ""}{k}-rest',
|
||||||
|
get_recursive_refs(rest, first_is_optional=True)
|
||||||
|
)
|
||||||
|
return res
|
||||||
|
|
||||||
|
rule += ' | '.join(
|
||||||
|
get_recursive_refs(optional_props[i:], first_is_optional=False)
|
||||||
|
for i in range(len(optional_props))
|
||||||
|
)
|
||||||
|
if required_props:
|
||||||
|
rule += ' )'
|
||||||
|
rule += ' )?'
|
||||||
|
|
||||||
|
rule += ' "}" space'
|
||||||
|
|
||||||
|
return rule
|
||||||
|
|
||||||
def format_grammar(self):
|
def format_grammar(self):
|
||||||
return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items()))
|
return '\n'.join(
|
||||||
|
f'{name} ::= {rule}'
|
||||||
|
for name, rule in sorted(self._rules.items(), key=lambda kv: kv[0])
|
||||||
|
)
|
||||||
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
|
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
|
||||||
prop_order = prop_order or []
|
prop_order = prop_order or []
|
||||||
schema = json.loads(schema)
|
schema = json.loads(schema)
|
||||||
prop_order = {name: idx for idx, name in enumerate(prop_order)}
|
prop_order = {name: idx for idx, name in enumerate(prop_order)}
|
||||||
converter = SchemaConverter(prop_order)
|
converter = SchemaConverter(prop_order=prop_order, allow_fetch=False, dotall=False, raw_pattern=False)
|
||||||
|
schema = converter.resolve_refs(schema, "stdin")
|
||||||
converter.visit(schema, "")
|
converter.visit(schema, "")
|
||||||
return converter.format_grammar()
|
return converter.format_grammar()
|
||||||
|
|
Loading…
Reference in a new issue