diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index c960e55..45a1513 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1399,6 +1399,7 @@ class SchemaConverter: def __init__(self, prop_order): self._prop_order = prop_order self._rules = {"space": SPACE_RULE} + self._defs = {} def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( @@ -1422,6 +1423,11 @@ class SchemaConverter: schema_type = schema.get("type") rule_name = name or "root" + if "$defs" in schema: + # add defs to self._defs for later inlining + for def_name, def_schema in schema["$defs"].items(): + self._defs[def_name] = def_schema + if "oneOf" in schema or "anyOf" in schema: rule = " | ".join( ( @@ -1471,6 +1477,14 @@ class SchemaConverter: ) return self._add_rule(rule_name, rule) + elif "$ref" in schema: + ref = schema["$ref"] + assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" + # inline $defs + def_name = ref[len("#/$defs/") :] + def_schema = self._defs[def_name] + return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') + else: assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" return self._add_rule(