fix: from_json_schema oneof/anyof bug. Closes #1097
This commit is contained in:
parent
8eefdbca03
commit
d3f5528ca8
2 changed files with 39 additions and 10 deletions
|
@ -1432,7 +1432,6 @@ class SchemaConverter:
|
||||||
return key
|
return key
|
||||||
|
|
||||||
def visit(self, schema: Dict[str, Any], name: str) -> str:
|
def visit(self, schema: Dict[str, Any], name: str) -> str:
|
||||||
schema_type: Optional[str] = schema.get("type") # type: ignore
|
|
||||||
rule_name = name or "root"
|
rule_name = name or "root"
|
||||||
|
|
||||||
if "$defs" in schema:
|
if "$defs" in schema:
|
||||||
|
@ -1458,7 +1457,19 @@ class SchemaConverter:
|
||||||
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
|
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
|
||||||
return self._add_rule(rule_name, rule)
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
elif schema_type == "object" and "properties" in schema:
|
elif "$ref" in schema:
|
||||||
|
ref = schema["$ref"]
|
||||||
|
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
|
||||||
|
# inline $defs
|
||||||
|
def_name = ref[len("#/$defs/") :]
|
||||||
|
def_schema = self._defs[def_name]
|
||||||
|
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
|
||||||
|
|
||||||
|
|
||||||
|
schema_type: Optional[str] = schema.get("type") # type: ignore
|
||||||
|
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
|
||||||
|
|
||||||
|
if schema_type == "object" and "properties" in schema:
|
||||||
# TODO: `required` keyword
|
# TODO: `required` keyword
|
||||||
prop_order = self._prop_order
|
prop_order = self._prop_order
|
||||||
prop_pairs = sorted(
|
prop_pairs = sorted(
|
||||||
|
@ -1489,14 +1500,6 @@ class SchemaConverter:
|
||||||
)
|
)
|
||||||
return self._add_rule(rule_name, rule)
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
elif "$ref" in schema:
|
|
||||||
ref = schema["$ref"]
|
|
||||||
assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}"
|
|
||||||
# inline $defs
|
|
||||||
def_name = ref[len("#/$defs/") :]
|
|
||||||
def_schema = self._defs[def_name]
|
|
||||||
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
||||||
return self._add_rule(
|
return self._add_rule(
|
||||||
|
|
|
@ -50,3 +50,29 @@ def test_composed_pydantic_grammar():
|
||||||
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
|
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
|
||||||
|
|
||||||
assert grammar.grammar is not None
|
assert grammar.grammar is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_grammar_anyof():
|
||||||
|
sch = {
|
||||||
|
"properties": {
|
||||||
|
"temperature": {
|
||||||
|
"description": "The temperature mentioned",
|
||||||
|
"type": "number",
|
||||||
|
},
|
||||||
|
"unit": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"description": "Unit for temperature",
|
||||||
|
"enum": ["celsius", "fahrenheit"],
|
||||||
|
"type": "string",
|
||||||
|
},
|
||||||
|
{"type": "null"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
}
|
||||||
|
|
||||||
|
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch))
|
||||||
|
|
||||||
|
assert grammar.grammar is not None
|
Loading…
Reference in a new issue