diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 45a1513..60d5194 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -58,7 +58,7 @@ class LlamaGrammar: ) # type: std.vector[std.vector[LlamaGrammarElement]] self._n_rules = self._grammar_rules.size() # type: int self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int - self.grammar = self.init() + self.init() @classmethod def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": diff --git a/tests/test_grammar.py b/tests/test_grammar.py new file mode 100644 index 0000000..2e24903 --- /dev/null +++ b/tests/test_grammar.py @@ -0,0 +1,13 @@ +import llama_cpp + +tree = """ +leaf ::= "." +node ::= leaf | "(" node node ")" +root ::= node +""" + +def test_grammar_from_string(): + grammar = llama_cpp.LlamaGrammar.from_string(tree) + assert grammar._n_rules == 3 + assert grammar._start_rule_index == 2 + assert grammar.grammar is not None