Fix Pydantic model parsing (#1087)

This commit is contained in:
Mark Neumann 2024-01-15 07:45:57 -08:00 committed by GitHub
parent 5502ac8876
commit c689ccc728
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 1 deletions

View file

@ -1433,7 +1433,6 @@ class SchemaConverter:
def visit(self, schema: Dict[str, Any], name: str) -> str:
schema_type: Optional[str] = schema.get("type") # type: ignore
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"
rule_name = name or "root"
if "$defs" in schema:

View file

@ -1,4 +1,5 @@
import llama_cpp
import json
tree = """
leaf ::= "."
@ -6,8 +7,46 @@ node ::= leaf | "(" node node ")"
root ::= node
"""
def test_grammar_from_string():
grammar = llama_cpp.LlamaGrammar.from_string(tree)
assert grammar._n_rules == 3
assert grammar._start_rule_index == 2
assert grammar.grammar is not None
def test_composed_pydantic_grammar():
"""
from pydantic import BaseModel
class A(BaseModel):
a: int
class B(BaseModel):
a: A
b: int
"""
# This schema corresponds to the grammar in the comment above.
# We don't use the pydantic models directly to avoid the dependency.
schema = {
"$defs": {
"A": {
"properties": {"a": {"title": "A", "type": "integer"}},
"required": ["a"],
"title": "A",
"type": "object",
}
},
"properties": {
"a": {"$ref": "#/$defs/A"},
"b": {"title": "B", "type": "integer"},
},
"required": ["a", "b"],
"title": "B",
"type": "object",
}
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))
assert grammar.grammar is not None