reset grammar for every generation

This commit is contained in:
c0sogi 2023-08-07 15:16:25 +09:00
parent 418aa83b01
commit b07713cb9f
2 changed files with 39 additions and 95 deletions

View file

@ -364,7 +364,7 @@ class Llama:
)
if grammar is not None:
self.grammar = LlamaGrammar.from_file(
grammar
grammar, verbose=verbose
) # type: Optional[LlamaGrammar]
else:
self.grammar = None
@ -723,7 +723,6 @@ class Llama:
The generated tokens.
"""
assert self.ctx is not None
if reset and len(self._input_ids) > 0:
longest_prefix = 0
for a, b in zip(self._input_ids, tokens[:-1]):
@ -741,6 +740,9 @@ class Llama:
if reset:
self.reset()
if self.grammar is not None:
self.grammar.reset()
while True:
self.eval(tokens)
token = self.sample(
@ -1534,9 +1536,6 @@ class Llama:
if self.ctx is not None:
llama_cpp.llama_free(self.ctx)
self.ctx = None
if self.grammar is not None:
llama_cpp.llama_grammar_free(self.grammar.grammar)
self.grammar = None
def __getstate__(self):
return dict(

View file

@ -1,6 +1,5 @@
"""C++ implementation of the llama grammar parser."""
# flake8: noqa
import argparse
from pathlib import Path
import sys
from ctypes import * # type: ignore
@ -19,7 +18,7 @@ from typing import (
overload,
)
import llama_cpp
from . import llama_cpp
# Type aliases
llama_grammar_element = llama_cpp.llama_grammar_element
@ -42,10 +41,18 @@ class LlamaGrammar:
"""Keeps reference counts of all the arguments, so that they are not
garbage collected by Python."""
def __del__(self) -> None:
"""Free the grammar pointer when the object is deleted."""
if self.grammar is not None:
llama_cpp.llama_grammar_free(self.grammar)
self.grammar = None
def __init__(
self,
parsed_grammar: "parse_state",
) -> None:
"""Initialize the grammar pointer from the parsed state."""
self.parsed_grammar = parsed_grammar
grammar_rules = (
parsed_grammar.c_rules()
) # type: std.vector[std.vector[llama_grammar_element]]
@ -69,22 +76,25 @@ class LlamaGrammar:
self.n_rules = c_size_t(grammar_rules.size())
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
self.grammar = self.init_grammar()
self._grammar = llama_cpp.llama_grammar_init(
self.rules, self.n_rules, self.start_rule_index
)
@classmethod
def from_string(cls, grammar: str) -> "LlamaGrammar":
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
parsed_grammar = parse(const_char_p(grammar)) # type: parse_state
if parsed_grammar.rules.empty():
raise ValueError(
f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
)
if verbose:
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
print_grammar(sys.stdout, parsed_grammar)
print(file=sys.stderr)
return cls(parsed_grammar)
@classmethod
def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
try:
with open(file) as f:
grammar = f.read()
@ -94,14 +104,27 @@ class LlamaGrammar:
)
if grammar:
return cls.from_string(grammar)
return cls.from_string(grammar, verbose=verbose)
raise ValueError(
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
)
def init_grammar(self) -> llama_grammar_p:
return llama_cpp.llama_grammar_init(
@property
def grammar(self) -> llama_grammar_p:
if self._grammar is None:
raise ValueError(
f"{self.__class__.__name__}.grammar: grammar is freed"
)
return self._grammar
@grammar.setter
def grammar(self, value: Optional[llama_grammar_p]) -> None:
self._grammar = value
def reset(self) -> None:
llama_cpp.llama_grammar_free(self.grammar)
self.grammar = llama_cpp.llama_grammar_init(
self.rules, self.n_rules, self.start_rule_index
)
@ -1217,81 +1240,3 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
f"{print_grammar.__name__}: error printing grammar: {err}",
file=sys.stderr,
)
# def convert_to_rules(
# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]],
# ) -> Array[llama_grammar_element_p]:
# """Make an Array object that is used for `llama_grammer_init`"""
# # Step 1: Convert each list to llama_grammar_element array and get pointer
# element_arrays = [
# (llama_grammar_element * len(subvector))(*subvector)
# for subvector in llama_grammar_elements
# ] # type: List[Array[llama_grammar_element]]
# # Step 2: Get pointer of each array
# element_array_pointers = [
# cast(subarray, llama_grammar_element_p) for subarray in element_arrays
# ] # type: List[llama_grammar_element_p]
# # Step 3: Make array of these pointers and get its pointer
# return (llama_grammar_element_p * len(element_array_pointers))(
# *element_array_pointers
# )
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate C++ parser from GBNF grammar"
)
parser.add_argument(
"-g",
"--grammar",
type=str,
default="./vendor/llama.cpp/grammars/json.gbnf",
help="path to GBNF grammar file",
)
args = parser.parse_args()
llama_grammar = LlamaGrammar.from_file(Path(args.grammar))
llama_grammar_ptr = llama_grammar.init_grammar()
# ----- USAGE:
# llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p)
# llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...)
# ----- SAMPLE OUTPUT:
# main grammar:
# root ::= object
# object ::= [{] ws object_11 [}] ws
# value ::= object | array | string | number | value_6 ws
# array ::= [[] ws array_15 []] ws
# string ::= ["] string_18 ["] ws
# number ::= number_19 number_25 number_29 ws
# value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l]
# ws ::= ws_31
# object_8 ::= string [:] ws value object_10
# object_9 ::= [,] ws string [:] ws value
# object_10 ::= object_9 object_10 |
# object_11 ::= object_8 |
# array_12 ::= value array_14
# array_13 ::= [,] ws value
# array_14 ::= array_13 array_14 |
# array_15 ::= array_12 |
# string_16 ::= [^"\] | [\] string_17
# string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]
# string_18 ::= string_16 string_18 |
# number_19 ::= number_20 number_21
# number_20 ::= [-] |
# number_21 ::= [0-9] | [1-9] number_22
# number_22 ::= [0-9] number_22 |
# number_23 ::= [.] number_24
# number_24 ::= [0-9] number_24 | [0-9]
# number_25 ::= number_23 |
# number_26 ::= [eE] number_27 number_28
# number_27 ::= [-+] |
# number_28 ::= [0-9] number_28 | [0-9]
# number_29 ::= number_26 |
# ws_30 ::= [ <U+0009><U+000A>] ws
# ws_31 ::= ws_30 |