From b07713cb9f26d97e0d5f740e722e60ad0d7e9ddf Mon Sep 17 00:00:00 2001 From: c0sogi Date: Mon, 7 Aug 2023 15:16:25 +0900 Subject: [PATCH] reset grammar for every generation --- llama_cpp/llama.py | 9 ++- llama_cpp/llama_grammar.py | 125 +++++++++++-------------------------- 2 files changed, 39 insertions(+), 95 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index ab99ee5..9328c5b 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -364,7 +364,7 @@ class Llama: ) if grammar is not None: self.grammar = LlamaGrammar.from_file( - grammar + grammar, verbose=verbose ) # type: Optional[LlamaGrammar] else: self.grammar = None @@ -723,7 +723,6 @@ class Llama: The generated tokens. """ assert self.ctx is not None - if reset and len(self._input_ids) > 0: longest_prefix = 0 for a, b in zip(self._input_ids, tokens[:-1]): @@ -741,6 +740,9 @@ class Llama: if reset: self.reset() + if self.grammar is not None: + self.grammar.reset() + while True: self.eval(tokens) token = self.sample( @@ -1534,9 +1536,6 @@ class Llama: if self.ctx is not None: llama_cpp.llama_free(self.ctx) self.ctx = None - if self.grammar is not None: - llama_cpp.llama_grammar_free(self.grammar.grammar) - self.grammar = None def __getstate__(self): return dict( diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 06b2b7f..5388676 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1,6 +1,5 @@ """C++ implementation of the llama grammar parser.""" # flake8: noqa -import argparse from pathlib import Path import sys from ctypes import * # type: ignore @@ -19,7 +18,7 @@ from typing import ( overload, ) -import llama_cpp +from . import llama_cpp # Type aliases llama_grammar_element = llama_cpp.llama_grammar_element @@ -41,11 +40,19 @@ class Sentinel: class LlamaGrammar: """Keeps reference counts of all the arguments, so that they are not garbage collected by Python.""" + + def __del__(self) -> None: + """Free the grammar pointer when the object is deleted.""" + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.grammar = None def __init__( self, parsed_grammar: "parse_state", ) -> None: + """Initialize the grammar pointer from the parsed state.""" + self.parsed_grammar = parsed_grammar grammar_rules = ( parsed_grammar.c_rules() ) # type: std.vector[std.vector[llama_grammar_element]] @@ -69,22 +76,25 @@ class LlamaGrammar: self.n_rules = c_size_t(grammar_rules.size()) self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) - self.grammar = self.init_grammar() + self._grammar = llama_cpp.llama_grammar_init( + self.rules, self.n_rules, self.start_rule_index + ) @classmethod - def from_string(cls, grammar: str) -> "LlamaGrammar": + def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": parsed_grammar = parse(const_char_p(grammar)) # type: parse_state if parsed_grammar.rules.empty(): raise ValueError( f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" ) - print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) - print_grammar(sys.stdout, parsed_grammar) - print(file=sys.stderr) + if verbose: + print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) + print_grammar(sys.stdout, parsed_grammar) + print(file=sys.stderr) return cls(parsed_grammar) @classmethod - def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar": + def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar": try: with open(file) as f: grammar = f.read() @@ -94,14 +104,27 @@ class LlamaGrammar: ) if grammar: - return cls.from_string(grammar) + return cls.from_string(grammar, verbose=verbose) raise ValueError( f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" ) - def init_grammar(self) -> llama_grammar_p: - return llama_cpp.llama_grammar_init( + @property + def grammar(self) -> llama_grammar_p: + if self._grammar is None: + raise ValueError( + f"{self.__class__.__name__}.grammar: grammar is freed" + ) + return self._grammar + + @grammar.setter + def grammar(self, value: Optional[llama_grammar_p]) -> None: + self._grammar = value + + def reset(self) -> None: + llama_cpp.llama_grammar_free(self.grammar) + self.grammar = llama_cpp.llama_grammar_init( self.rules, self.n_rules, self.start_rule_index ) @@ -1216,82 +1239,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None: print( f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, - ) - - -# def convert_to_rules( -# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]], -# ) -> Array[llama_grammar_element_p]: -# """Make an Array object that is used for `llama_grammer_init`""" - -# # Step 1: Convert each list to llama_grammar_element array and get pointer -# element_arrays = [ -# (llama_grammar_element * len(subvector))(*subvector) -# for subvector in llama_grammar_elements -# ] # type: List[Array[llama_grammar_element]] - -# # Step 2: Get pointer of each array -# element_array_pointers = [ -# cast(subarray, llama_grammar_element_p) for subarray in element_arrays -# ] # type: List[llama_grammar_element_p] - -# # Step 3: Make array of these pointers and get its pointer -# return (llama_grammar_element_p * len(element_array_pointers))( -# *element_array_pointers -# ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Generate C++ parser from GBNF grammar" - ) - parser.add_argument( - "-g", - "--grammar", - type=str, - default="./vendor/llama.cpp/grammars/json.gbnf", - help="path to GBNF grammar file", - ) - - args = parser.parse_args() - llama_grammar = LlamaGrammar.from_file(Path(args.grammar)) - llama_grammar_ptr = llama_grammar.init_grammar() - - # ----- USAGE: - # llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p) - # llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...) - - # ----- SAMPLE OUTPUT: - # main grammar: - # root ::= object - # object ::= [{] ws object_11 [}] ws - # value ::= object | array | string | number | value_6 ws - # array ::= [[] ws array_15 []] ws - # string ::= ["] string_18 ["] ws - # number ::= number_19 number_25 number_29 ws - # value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l] - # ws ::= ws_31 - # object_8 ::= string [:] ws value object_10 - # object_9 ::= [,] ws string [:] ws value - # object_10 ::= object_9 object_10 | - # object_11 ::= object_8 | - # array_12 ::= value array_14 - # array_13 ::= [,] ws value - # array_14 ::= array_13 array_14 | - # array_15 ::= array_12 | - # string_16 ::= [^"\] | [\] string_17 - # string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] - # string_18 ::= string_16 string_18 | - # number_19 ::= number_20 number_21 - # number_20 ::= [-] | - # number_21 ::= [0-9] | [1-9] number_22 - # number_22 ::= [0-9] number_22 | - # number_23 ::= [.] number_24 - # number_24 ::= [0-9] number_24 | [0-9] - # number_25 ::= number_23 | - # number_26 ::= [eE] number_27 number_28 - # number_27 ::= [-+] | - # number_28 ::= [0-9] number_28 | [0-9] - # number_29 ::= number_26 | - # ws_30 ::= [ ] ws - # ws_31 ::= ws_30 | + ) \ No newline at end of file