diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index c02e656..d8ef563 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1432,7 +1432,6 @@ class SchemaConverter: return key def visit(self, schema: Dict[str, Any], name: str) -> str: - schema_type: Optional[str] = schema.get("type") # type: ignore rule_name = name or "root" if "$defs" in schema: @@ -1458,7 +1457,19 @@ class SchemaConverter: rule = " | ".join((self._format_literal(v) for v in schema["enum"])) return self._add_rule(rule_name, rule) - elif schema_type == "object" and "properties" in schema: + elif "$ref" in schema: + ref = schema["$ref"] + assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" + # inline $defs + def_name = ref[len("#/$defs/") :] + def_schema = self._defs[def_name] + return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') + + + schema_type: Optional[str] = schema.get("type") # type: ignore + assert isinstance(schema_type, str), f"Unrecognized schema: {schema}" + + if schema_type == "object" and "properties" in schema: # TODO: `required` keyword prop_order = self._prop_order prop_pairs = sorted( @@ -1489,14 +1500,6 @@ class SchemaConverter: ) return self._add_rule(rule_name, rule) - elif "$ref" in schema: - ref = schema["$ref"] - assert ref.startswith("#/$defs/"), f"Unrecognized schema: {schema}" - # inline $defs - def_name = ref[len("#/$defs/") :] - def_schema = self._defs[def_name] - return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}') - else: assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}" return self._add_rule( diff --git a/tests/test_grammar.py b/tests/test_grammar.py index ef9392b..cb22188 100644 --- a/tests/test_grammar.py +++ b/tests/test_grammar.py @@ -50,3 +50,29 @@ def test_composed_pydantic_grammar(): grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema)) assert grammar.grammar is not None + + +def test_grammar_anyof(): + sch = { + "properties": { + "temperature": { + "description": "The temperature mentioned", + "type": "number", + }, + "unit": { + "anyOf": [ + { + "description": "Unit for temperature", + "enum": ["celsius", "fahrenheit"], + "type": "string", + }, + {"type": "null"}, + ], + }, + }, + "type": "object", + } + + grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch)) + + assert grammar.grammar is not None \ No newline at end of file