reset grammar for every generation

This commit is contained in:
c0sogi 2023-08-07 15:16:25 +09:00
parent 418aa83b01
commit b07713cb9f
2 changed files with 39 additions and 95 deletions

View file

@ -364,7 +364,7 @@ class Llama:
) )
if grammar is not None: if grammar is not None:
self.grammar = LlamaGrammar.from_file( self.grammar = LlamaGrammar.from_file(
grammar grammar, verbose=verbose
) # type: Optional[LlamaGrammar] ) # type: Optional[LlamaGrammar]
else: else:
self.grammar = None self.grammar = None
@ -723,7 +723,6 @@ class Llama:
The generated tokens. The generated tokens.
""" """
assert self.ctx is not None assert self.ctx is not None
if reset and len(self._input_ids) > 0: if reset and len(self._input_ids) > 0:
longest_prefix = 0 longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]): for a, b in zip(self._input_ids, tokens[:-1]):
@ -741,6 +740,9 @@ class Llama:
if reset: if reset:
self.reset() self.reset()
if self.grammar is not None:
self.grammar.reset()
while True: while True:
self.eval(tokens) self.eval(tokens)
token = self.sample( token = self.sample(
@ -1534,9 +1536,6 @@ class Llama:
if self.ctx is not None: if self.ctx is not None:
llama_cpp.llama_free(self.ctx) llama_cpp.llama_free(self.ctx)
self.ctx = None self.ctx = None
if self.grammar is not None:
llama_cpp.llama_grammar_free(self.grammar.grammar)
self.grammar = None
def __getstate__(self): def __getstate__(self):
return dict( return dict(

View file

@ -1,6 +1,5 @@
"""C++ implementation of the llama grammar parser.""" """C++ implementation of the llama grammar parser."""
# flake8: noqa # flake8: noqa
import argparse
from pathlib import Path from pathlib import Path
import sys import sys
from ctypes import * # type: ignore from ctypes import * # type: ignore
@ -19,7 +18,7 @@ from typing import (
overload, overload,
) )
import llama_cpp from . import llama_cpp
# Type aliases # Type aliases
llama_grammar_element = llama_cpp.llama_grammar_element llama_grammar_element = llama_cpp.llama_grammar_element
@ -42,10 +41,18 @@ class LlamaGrammar:
"""Keeps reference counts of all the arguments, so that they are not """Keeps reference counts of all the arguments, so that they are not
garbage collected by Python.""" garbage collected by Python."""
def __del__(self) -> None:
"""Free the grammar pointer when the object is deleted."""
if self.grammar is not None:
llama_cpp.llama_grammar_free(self.grammar)
self.grammar = None
def __init__( def __init__(
self, self,
parsed_grammar: "parse_state", parsed_grammar: "parse_state",
) -> None: ) -> None:
"""Initialize the grammar pointer from the parsed state."""
self.parsed_grammar = parsed_grammar
grammar_rules = ( grammar_rules = (
parsed_grammar.c_rules() parsed_grammar.c_rules()
) # type: std.vector[std.vector[llama_grammar_element]] ) # type: std.vector[std.vector[llama_grammar_element]]
@ -69,22 +76,25 @@ class LlamaGrammar:
self.n_rules = c_size_t(grammar_rules.size()) self.n_rules = c_size_t(grammar_rules.size())
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
self.grammar = self.init_grammar() self._grammar = llama_cpp.llama_grammar_init(
self.rules, self.n_rules, self.start_rule_index
)
@classmethod @classmethod
def from_string(cls, grammar: str) -> "LlamaGrammar": def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
parsed_grammar = parse(const_char_p(grammar)) # type: parse_state parsed_grammar = parse(const_char_p(grammar)) # type: parse_state
if parsed_grammar.rules.empty(): if parsed_grammar.rules.empty():
raise ValueError( raise ValueError(
f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty" f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
) )
if verbose:
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr) print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
print_grammar(sys.stdout, parsed_grammar) print_grammar(sys.stdout, parsed_grammar)
print(file=sys.stderr) print(file=sys.stderr)
return cls(parsed_grammar) return cls(parsed_grammar)
@classmethod @classmethod
def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar": def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
try: try:
with open(file) as f: with open(file) as f:
grammar = f.read() grammar = f.read()
@ -94,14 +104,27 @@ class LlamaGrammar:
) )
if grammar: if grammar:
return cls.from_string(grammar) return cls.from_string(grammar, verbose=verbose)
raise ValueError( raise ValueError(
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
) )
def init_grammar(self) -> llama_grammar_p: @property
return llama_cpp.llama_grammar_init( def grammar(self) -> llama_grammar_p:
if self._grammar is None:
raise ValueError(
f"{self.__class__.__name__}.grammar: grammar is freed"
)
return self._grammar
@grammar.setter
def grammar(self, value: Optional[llama_grammar_p]) -> None:
self._grammar = value
def reset(self) -> None:
llama_cpp.llama_grammar_free(self.grammar)
self.grammar = llama_cpp.llama_grammar_init(
self.rules, self.n_rules, self.start_rule_index self.rules, self.n_rules, self.start_rule_index
) )
@ -1217,81 +1240,3 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
f"{print_grammar.__name__}: error printing grammar: {err}", f"{print_grammar.__name__}: error printing grammar: {err}",
file=sys.stderr, file=sys.stderr,
) )
# def convert_to_rules(
# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]],
# ) -> Array[llama_grammar_element_p]:
# """Make an Array object that is used for `llama_grammer_init`"""
# # Step 1: Convert each list to llama_grammar_element array and get pointer
# element_arrays = [
# (llama_grammar_element * len(subvector))(*subvector)
# for subvector in llama_grammar_elements
# ] # type: List[Array[llama_grammar_element]]
# # Step 2: Get pointer of each array
# element_array_pointers = [
# cast(subarray, llama_grammar_element_p) for subarray in element_arrays
# ] # type: List[llama_grammar_element_p]
# # Step 3: Make array of these pointers and get its pointer
# return (llama_grammar_element_p * len(element_array_pointers))(
# *element_array_pointers
# )
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate C++ parser from GBNF grammar"
)
parser.add_argument(
"-g",
"--grammar",
type=str,
default="./vendor/llama.cpp/grammars/json.gbnf",
help="path to GBNF grammar file",
)
args = parser.parse_args()
llama_grammar = LlamaGrammar.from_file(Path(args.grammar))
llama_grammar_ptr = llama_grammar.init_grammar()
# ----- USAGE:
# llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p)
# llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...)
# ----- SAMPLE OUTPUT:
# main grammar:
# root ::= object
# object ::= [{] ws object_11 [}] ws
# value ::= object | array | string | number | value_6 ws
# array ::= [[] ws array_15 []] ws
# string ::= ["] string_18 ["] ws
# number ::= number_19 number_25 number_29 ws
# value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l]
# ws ::= ws_31
# object_8 ::= string [:] ws value object_10
# object_9 ::= [,] ws string [:] ws value
# object_10 ::= object_9 object_10 |
# object_11 ::= object_8 |
# array_12 ::= value array_14
# array_13 ::= [,] ws value
# array_14 ::= array_13 array_14 |
# array_15 ::= array_12 |
# string_16 ::= [^"\] | [\] string_17
# string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]
# string_18 ::= string_16 string_18 |
# number_19 ::= number_20 number_21
# number_20 ::= [-] |
# number_21 ::= [0-9] | [1-9] number_22
# number_22 ::= [0-9] number_22 |
# number_23 ::= [.] number_24
# number_24 ::= [0-9] number_24 | [0-9]
# number_25 ::= number_23 |
# number_26 ::= [eE] number_27 number_28
# number_27 ::= [-+] |
# number_28 ::= [0-9] number_28 | [0-9]
# number_29 ::= number_26 |
# ws_30 ::= [ <U+0009><U+000A>] ws
# ws_31 ::= ws_30 |