fix: from_json_schema oneof/anyof bug. Closes #1097

This commit is contained in:
Andrei Betlen 2024-01-21 19:06:53 -05:00
parent 8eefdbca03
commit d3f5528ca8
2 changed files with 39 additions and 10 deletions

View file

@ -1432,7 +1432,6 @@ class SchemaConverter:
return key return key
def visit(self, schema: Dict[str, Any], name: str) -> str: def visit(self, schema: Dict[str, Any], name: str) -> str:
schema_type: Optional[str] = schema.get("type") # type: ignore
rule_name = name or "root" rule_name = name or "root"
if "$defs" in schema: if "$defs" in schema:
@ -1458,7 +1457,19 @@ class SchemaConverter:
rule = " | ".join((self._format_literal(v) for v in schema["enum"])) rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
return self._add_rule(rule_name, rule) return self._add_rule(rule_name, rule)
elif schema_type == "object" and "properties" in schema: elif "$ref" in schema:
ref = schema["$ref"]
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
# inline $defs
def_name = ref[len("#/$defs/") :]
def_schema = self._defs[def_name]
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
schema_type: Optional[str] = schema.get("type") # type: ignore
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
if schema_type == "object" and "properties" in schema:
# TODO: `required` keyword # TODO: `required` keyword
prop_order = self._prop_order prop_order = self._prop_order
prop_pairs = sorted( prop_pairs = sorted(
@ -1489,14 +1500,6 @@ class SchemaConverter:
) )
return self._add_rule(rule_name, rule) return self._add_rule(rule_name, rule)
elif "$ref" in schema:
ref = schema["$ref"]
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
# inline $defs
def_name = ref[len("#/$defs/") :]
def_schema = self._defs[def_name]
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
else: else:
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
return self._add_rule( return self._add_rule(

View file

@ -50,3 +50,29 @@ def test_composed_pydantic_grammar():
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema)) grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
assert grammar.grammar is not None assert grammar.grammar is not None
def test_grammar_anyof():
sch = {
"properties": {
"temperature": {
"description": "The temperature mentioned",
"type": "number",
},
"unit": {
"anyOf": [
{
"description": "Unit for temperature",
"enum": ["celsius", "fahrenheit"],
"type": "string",
},
{"type": "null"},
],
},
},
"type": "object",
}
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch))
assert grammar.grammar is not None