reset grammar for every generation
This commit is contained in:
parent
418aa83b01
commit
b07713cb9f
2 changed files with 39 additions and 95 deletions
|
@ -364,7 +364,7 @@ class Llama:
|
|||
)
|
||||
if grammar is not None:
|
||||
self.grammar = LlamaGrammar.from_file(
|
||||
grammar
|
||||
grammar, verbose=verbose
|
||||
) # type: Optional[LlamaGrammar]
|
||||
else:
|
||||
self.grammar = None
|
||||
|
@ -723,7 +723,6 @@ class Llama:
|
|||
The generated tokens.
|
||||
"""
|
||||
assert self.ctx is not None
|
||||
|
||||
if reset and len(self._input_ids) > 0:
|
||||
longest_prefix = 0
|
||||
for a, b in zip(self._input_ids, tokens[:-1]):
|
||||
|
@ -741,6 +740,9 @@ class Llama:
|
|||
if reset:
|
||||
self.reset()
|
||||
|
||||
if self.grammar is not None:
|
||||
self.grammar.reset()
|
||||
|
||||
while True:
|
||||
self.eval(tokens)
|
||||
token = self.sample(
|
||||
|
@ -1534,9 +1536,6 @@ class Llama:
|
|||
if self.ctx is not None:
|
||||
llama_cpp.llama_free(self.ctx)
|
||||
self.ctx = None
|
||||
if self.grammar is not None:
|
||||
llama_cpp.llama_grammar_free(self.grammar.grammar)
|
||||
self.grammar = None
|
||||
|
||||
def __getstate__(self):
|
||||
return dict(
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
"""C++ implementation of the llama grammar parser."""
|
||||
# flake8: noqa
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from ctypes import * # type: ignore
|
||||
|
@ -19,7 +18,7 @@ from typing import (
|
|||
overload,
|
||||
)
|
||||
|
||||
import llama_cpp
|
||||
from . import llama_cpp
|
||||
|
||||
# Type aliases
|
||||
llama_grammar_element = llama_cpp.llama_grammar_element
|
||||
|
@ -42,10 +41,18 @@ class LlamaGrammar:
|
|||
"""Keeps reference counts of all the arguments, so that they are not
|
||||
garbage collected by Python."""
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Free the grammar pointer when the object is deleted."""
|
||||
if self.grammar is not None:
|
||||
llama_cpp.llama_grammar_free(self.grammar)
|
||||
self.grammar = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parsed_grammar: "parse_state",
|
||||
) -> None:
|
||||
"""Initialize the grammar pointer from the parsed state."""
|
||||
self.parsed_grammar = parsed_grammar
|
||||
grammar_rules = (
|
||||
parsed_grammar.c_rules()
|
||||
) # type: std.vector[std.vector[llama_grammar_element]]
|
||||
|
@ -69,22 +76,25 @@ class LlamaGrammar:
|
|||
|
||||
self.n_rules = c_size_t(grammar_rules.size())
|
||||
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
|
||||
self.grammar = self.init_grammar()
|
||||
self._grammar = llama_cpp.llama_grammar_init(
|
||||
self.rules, self.n_rules, self.start_rule_index
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, grammar: str) -> "LlamaGrammar":
|
||||
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
|
||||
parsed_grammar = parse(const_char_p(grammar)) # type: parse_state
|
||||
if parsed_grammar.rules.empty():
|
||||
raise ValueError(
|
||||
f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
|
||||
)
|
||||
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
|
||||
print_grammar(sys.stdout, parsed_grammar)
|
||||
print(file=sys.stderr)
|
||||
if verbose:
|
||||
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
|
||||
print_grammar(sys.stdout, parsed_grammar)
|
||||
print(file=sys.stderr)
|
||||
return cls(parsed_grammar)
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
|
||||
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
|
||||
try:
|
||||
with open(file) as f:
|
||||
grammar = f.read()
|
||||
|
@ -94,14 +104,27 @@ class LlamaGrammar:
|
|||
)
|
||||
|
||||
if grammar:
|
||||
return cls.from_string(grammar)
|
||||
return cls.from_string(grammar, verbose=verbose)
|
||||
|
||||
raise ValueError(
|
||||
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
|
||||
)
|
||||
|
||||
def init_grammar(self) -> llama_grammar_p:
|
||||
return llama_cpp.llama_grammar_init(
|
||||
@property
|
||||
def grammar(self) -> llama_grammar_p:
|
||||
if self._grammar is None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__}.grammar: grammar is freed"
|
||||
)
|
||||
return self._grammar
|
||||
|
||||
@grammar.setter
|
||||
def grammar(self, value: Optional[llama_grammar_p]) -> None:
|
||||
self._grammar = value
|
||||
|
||||
def reset(self) -> None:
|
||||
llama_cpp.llama_grammar_free(self.grammar)
|
||||
self.grammar = llama_cpp.llama_grammar_init(
|
||||
self.rules, self.n_rules, self.start_rule_index
|
||||
)
|
||||
|
||||
|
@ -1217,81 +1240,3 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
|
|||
f"{print_grammar.__name__}: error printing grammar: {err}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
|
||||
# def convert_to_rules(
|
||||
# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]],
|
||||
# ) -> Array[llama_grammar_element_p]:
|
||||
# """Make an Array object that is used for `llama_grammer_init`"""
|
||||
|
||||
# # Step 1: Convert each list to llama_grammar_element array and get pointer
|
||||
# element_arrays = [
|
||||
# (llama_grammar_element * len(subvector))(*subvector)
|
||||
# for subvector in llama_grammar_elements
|
||||
# ] # type: List[Array[llama_grammar_element]]
|
||||
|
||||
# # Step 2: Get pointer of each array
|
||||
# element_array_pointers = [
|
||||
# cast(subarray, llama_grammar_element_p) for subarray in element_arrays
|
||||
# ] # type: List[llama_grammar_element_p]
|
||||
|
||||
# # Step 3: Make array of these pointers and get its pointer
|
||||
# return (llama_grammar_element_p * len(element_array_pointers))(
|
||||
# *element_array_pointers
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate C++ parser from GBNF grammar"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--grammar",
|
||||
type=str,
|
||||
default="./vendor/llama.cpp/grammars/json.gbnf",
|
||||
help="path to GBNF grammar file",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
llama_grammar = LlamaGrammar.from_file(Path(args.grammar))
|
||||
llama_grammar_ptr = llama_grammar.init_grammar()
|
||||
|
||||
# ----- USAGE:
|
||||
# llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p)
|
||||
# llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...)
|
||||
|
||||
# ----- SAMPLE OUTPUT:
|
||||
# main grammar:
|
||||
# root ::= object
|
||||
# object ::= [{] ws object_11 [}] ws
|
||||
# value ::= object | array | string | number | value_6 ws
|
||||
# array ::= [[] ws array_15 []] ws
|
||||
# string ::= ["] string_18 ["] ws
|
||||
# number ::= number_19 number_25 number_29 ws
|
||||
# value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l]
|
||||
# ws ::= ws_31
|
||||
# object_8 ::= string [:] ws value object_10
|
||||
# object_9 ::= [,] ws string [:] ws value
|
||||
# object_10 ::= object_9 object_10 |
|
||||
# object_11 ::= object_8 |
|
||||
# array_12 ::= value array_14
|
||||
# array_13 ::= [,] ws value
|
||||
# array_14 ::= array_13 array_14 |
|
||||
# array_15 ::= array_12 |
|
||||
# string_16 ::= [^"\] | [\] string_17
|
||||
# string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]
|
||||
# string_18 ::= string_16 string_18 |
|
||||
# number_19 ::= number_20 number_21
|
||||
# number_20 ::= [-] |
|
||||
# number_21 ::= [0-9] | [1-9] number_22
|
||||
# number_22 ::= [0-9] number_22 |
|
||||
# number_23 ::= [.] number_24
|
||||
# number_24 ::= [0-9] number_24 | [0-9]
|
||||
# number_25 ::= number_23 |
|
||||
# number_26 ::= [eE] number_27 number_28
|
||||
# number_27 ::= [-+] |
|
||||
# number_28 ::= [0-9] number_28 | [0-9]
|
||||
# number_29 ::= number_26 |
|
||||
# ws_30 ::= [ <U+0009><U+000A>] ws
|
||||
# ws_31 ::= ws_30 |
|
||||
|
|
Loading…
Reference in a new issue