reset grammar for every generation
This commit is contained in:
parent
418aa83b01
commit
b07713cb9f
2 changed files with 39 additions and 95 deletions
|
@ -364,7 +364,7 @@ class Llama:
|
||||||
)
|
)
|
||||||
if grammar is not None:
|
if grammar is not None:
|
||||||
self.grammar = LlamaGrammar.from_file(
|
self.grammar = LlamaGrammar.from_file(
|
||||||
grammar
|
grammar, verbose=verbose
|
||||||
) # type: Optional[LlamaGrammar]
|
) # type: Optional[LlamaGrammar]
|
||||||
else:
|
else:
|
||||||
self.grammar = None
|
self.grammar = None
|
||||||
|
@ -723,7 +723,6 @@ class Llama:
|
||||||
The generated tokens.
|
The generated tokens.
|
||||||
"""
|
"""
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
|
|
||||||
if reset and len(self._input_ids) > 0:
|
if reset and len(self._input_ids) > 0:
|
||||||
longest_prefix = 0
|
longest_prefix = 0
|
||||||
for a, b in zip(self._input_ids, tokens[:-1]):
|
for a, b in zip(self._input_ids, tokens[:-1]):
|
||||||
|
@ -741,6 +740,9 @@ class Llama:
|
||||||
if reset:
|
if reset:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
if self.grammar is not None:
|
||||||
|
self.grammar.reset()
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
self.eval(tokens)
|
self.eval(tokens)
|
||||||
token = self.sample(
|
token = self.sample(
|
||||||
|
@ -1534,9 +1536,6 @@ class Llama:
|
||||||
if self.ctx is not None:
|
if self.ctx is not None:
|
||||||
llama_cpp.llama_free(self.ctx)
|
llama_cpp.llama_free(self.ctx)
|
||||||
self.ctx = None
|
self.ctx = None
|
||||||
if self.grammar is not None:
|
|
||||||
llama_cpp.llama_grammar_free(self.grammar.grammar)
|
|
||||||
self.grammar = None
|
|
||||||
|
|
||||||
def __getstate__(self):
|
def __getstate__(self):
|
||||||
return dict(
|
return dict(
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""C++ implementation of the llama grammar parser."""
|
"""C++ implementation of the llama grammar parser."""
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
import argparse
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
from ctypes import * # type: ignore
|
from ctypes import * # type: ignore
|
||||||
|
@ -19,7 +18,7 @@ from typing import (
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
|
|
||||||
import llama_cpp
|
from . import llama_cpp
|
||||||
|
|
||||||
# Type aliases
|
# Type aliases
|
||||||
llama_grammar_element = llama_cpp.llama_grammar_element
|
llama_grammar_element = llama_cpp.llama_grammar_element
|
||||||
|
@ -42,10 +41,18 @@ class LlamaGrammar:
|
||||||
"""Keeps reference counts of all the arguments, so that they are not
|
"""Keeps reference counts of all the arguments, so that they are not
|
||||||
garbage collected by Python."""
|
garbage collected by Python."""
|
||||||
|
|
||||||
|
def __del__(self) -> None:
|
||||||
|
"""Free the grammar pointer when the object is deleted."""
|
||||||
|
if self.grammar is not None:
|
||||||
|
llama_cpp.llama_grammar_free(self.grammar)
|
||||||
|
self.grammar = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parsed_grammar: "parse_state",
|
parsed_grammar: "parse_state",
|
||||||
) -> None:
|
) -> None:
|
||||||
|
"""Initialize the grammar pointer from the parsed state."""
|
||||||
|
self.parsed_grammar = parsed_grammar
|
||||||
grammar_rules = (
|
grammar_rules = (
|
||||||
parsed_grammar.c_rules()
|
parsed_grammar.c_rules()
|
||||||
) # type: std.vector[std.vector[llama_grammar_element]]
|
) # type: std.vector[std.vector[llama_grammar_element]]
|
||||||
|
@ -69,22 +76,25 @@ class LlamaGrammar:
|
||||||
|
|
||||||
self.n_rules = c_size_t(grammar_rules.size())
|
self.n_rules = c_size_t(grammar_rules.size())
|
||||||
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
|
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
|
||||||
self.grammar = self.init_grammar()
|
self._grammar = llama_cpp.llama_grammar_init(
|
||||||
|
self.rules, self.n_rules, self.start_rule_index
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, grammar: str) -> "LlamaGrammar":
|
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
|
||||||
parsed_grammar = parse(const_char_p(grammar)) # type: parse_state
|
parsed_grammar = parse(const_char_p(grammar)) # type: parse_state
|
||||||
if parsed_grammar.rules.empty():
|
if parsed_grammar.rules.empty():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
|
f"{cls.from_string.__name__}: error parsing grammar file: parsed_grammar.rules is empty"
|
||||||
)
|
)
|
||||||
|
if verbose:
|
||||||
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
|
print(f"{cls.from_string.__name__} grammar:", file=sys.stderr)
|
||||||
print_grammar(sys.stdout, parsed_grammar)
|
print_grammar(sys.stdout, parsed_grammar)
|
||||||
print(file=sys.stderr)
|
print(file=sys.stderr)
|
||||||
return cls(parsed_grammar)
|
return cls(parsed_grammar)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_file(cls, file: Union[str, Path]) -> "LlamaGrammar":
|
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
|
||||||
try:
|
try:
|
||||||
with open(file) as f:
|
with open(file) as f:
|
||||||
grammar = f.read()
|
grammar = f.read()
|
||||||
|
@ -94,14 +104,27 @@ class LlamaGrammar:
|
||||||
)
|
)
|
||||||
|
|
||||||
if grammar:
|
if grammar:
|
||||||
return cls.from_string(grammar)
|
return cls.from_string(grammar, verbose=verbose)
|
||||||
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
|
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_grammar(self) -> llama_grammar_p:
|
@property
|
||||||
return llama_cpp.llama_grammar_init(
|
def grammar(self) -> llama_grammar_p:
|
||||||
|
if self._grammar is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__.__name__}.grammar: grammar is freed"
|
||||||
|
)
|
||||||
|
return self._grammar
|
||||||
|
|
||||||
|
@grammar.setter
|
||||||
|
def grammar(self, value: Optional[llama_grammar_p]) -> None:
|
||||||
|
self._grammar = value
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
llama_cpp.llama_grammar_free(self.grammar)
|
||||||
|
self.grammar = llama_cpp.llama_grammar_init(
|
||||||
self.rules, self.n_rules, self.start_rule_index
|
self.rules, self.n_rules, self.start_rule_index
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1217,81 +1240,3 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
|
||||||
f"{print_grammar.__name__}: error printing grammar: {err}",
|
f"{print_grammar.__name__}: error printing grammar: {err}",
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# def convert_to_rules(
|
|
||||||
# llama_grammar_elements: std.vector[std.vector[llama_grammar_element]],
|
|
||||||
# ) -> Array[llama_grammar_element_p]:
|
|
||||||
# """Make an Array object that is used for `llama_grammer_init`"""
|
|
||||||
|
|
||||||
# # Step 1: Convert each list to llama_grammar_element array and get pointer
|
|
||||||
# element_arrays = [
|
|
||||||
# (llama_grammar_element * len(subvector))(*subvector)
|
|
||||||
# for subvector in llama_grammar_elements
|
|
||||||
# ] # type: List[Array[llama_grammar_element]]
|
|
||||||
|
|
||||||
# # Step 2: Get pointer of each array
|
|
||||||
# element_array_pointers = [
|
|
||||||
# cast(subarray, llama_grammar_element_p) for subarray in element_arrays
|
|
||||||
# ] # type: List[llama_grammar_element_p]
|
|
||||||
|
|
||||||
# # Step 3: Make array of these pointers and get its pointer
|
|
||||||
# return (llama_grammar_element_p * len(element_array_pointers))(
|
|
||||||
# *element_array_pointers
|
|
||||||
# )
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Generate C++ parser from GBNF grammar"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"-g",
|
|
||||||
"--grammar",
|
|
||||||
type=str,
|
|
||||||
default="./vendor/llama.cpp/grammars/json.gbnf",
|
|
||||||
help="path to GBNF grammar file",
|
|
||||||
)
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
llama_grammar = LlamaGrammar.from_file(Path(args.grammar))
|
|
||||||
llama_grammar_ptr = llama_grammar.init_grammar()
|
|
||||||
|
|
||||||
# ----- USAGE:
|
|
||||||
# llama_cpp.llama_sample_grammar(ctx=..., candidates=..., grammar=llama_grammar_p)
|
|
||||||
# llama_cpp.llama_grammar_accept_token(ctx=..., grammar=llama_grammar_p, token=...)
|
|
||||||
|
|
||||||
# ----- SAMPLE OUTPUT:
|
|
||||||
# main grammar:
|
|
||||||
# root ::= object
|
|
||||||
# object ::= [{] ws object_11 [}] ws
|
|
||||||
# value ::= object | array | string | number | value_6 ws
|
|
||||||
# array ::= [[] ws array_15 []] ws
|
|
||||||
# string ::= ["] string_18 ["] ws
|
|
||||||
# number ::= number_19 number_25 number_29 ws
|
|
||||||
# value_6 ::= [t] [r] [u] [e] | [f] [a] [l] [s] [e] | [n] [u] [l] [l]
|
|
||||||
# ws ::= ws_31
|
|
||||||
# object_8 ::= string [:] ws value object_10
|
|
||||||
# object_9 ::= [,] ws string [:] ws value
|
|
||||||
# object_10 ::= object_9 object_10 |
|
|
||||||
# object_11 ::= object_8 |
|
|
||||||
# array_12 ::= value array_14
|
|
||||||
# array_13 ::= [,] ws value
|
|
||||||
# array_14 ::= array_13 array_14 |
|
|
||||||
# array_15 ::= array_12 |
|
|
||||||
# string_16 ::= [^"\] | [\] string_17
|
|
||||||
# string_17 ::= ["\/bfnrt] | [u] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]
|
|
||||||
# string_18 ::= string_16 string_18 |
|
|
||||||
# number_19 ::= number_20 number_21
|
|
||||||
# number_20 ::= [-] |
|
|
||||||
# number_21 ::= [0-9] | [1-9] number_22
|
|
||||||
# number_22 ::= [0-9] number_22 |
|
|
||||||
# number_23 ::= [.] number_24
|
|
||||||
# number_24 ::= [0-9] number_24 | [0-9]
|
|
||||||
# number_25 ::= number_23 |
|
|
||||||
# number_26 ::= [eE] number_27 number_28
|
|
||||||
# number_27 ::= [-+] |
|
|
||||||
# number_28 ::= [0-9] number_28 | [0-9]
|
|
||||||
# number_29 ::= number_26 |
|
|
||||||
# ws_30 ::= [ <U+0009><U+000A>] ws
|
|
||||||
# ws_31 ::= ws_30 |
|
|
||||||
|
|
Loading…
Reference in a new issue