From 770df344369c0630df1be14be9f9e301e7c56d24 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 10 Nov 2023 02:50:46 -0500 Subject: [PATCH] Add $ref and $defs support to json schema converter --- llama_cpp/llama_grammar.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index c960e55..45a1513 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1399,6 +1399,7 @@ class SchemaConverter: def __init__(self, prop_order): self._prop_order = prop_order self._rules = {"space": SPACE_RULE} + self._defs = {} def _format_literal(self, literal): escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub( @@ -1422,6 +1423,11 @@ class SchemaConverter: schema_type = schema.get("type") rule_name = name or "root" + if "$defs" in schema: + # add defs to self._defs for later inlining + for def_name, def_schema in schema["$defs"].items(): + self._defs[def_name] = def_schema + if "oneOf" in schema or "anyOf" in schema: rule = " | ".join( ( @@ -1471,6 +1477,14 @@ class SchemaConverter: ) return self._add_rule(rule_name, rule) + elif "$ref" in schema: + ref = schema["$ref"] + assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" + # inline $defs + def_name = ref[len("#/$defs/") :] + def_schema = self._defs[def_name] + return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') + else: assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" return self._add_rule(