prevent memory access error by llama_grammar_free

This commit is contained in:
c0sogi 2023-08-07 17:02:33 +09:00
parent b07713cb9f
commit 0d7d2031a9

View file

@ -52,33 +52,12 @@ class LlamaGrammar:
parsed_grammar: "parse_state",
) -> None:
"""Initialize the grammar pointer from the parsed state."""
self.parsed_grammar = parsed_grammar
grammar_rules = (
self._grammar_rules = (
parsed_grammar.c_rules()
) # type: std.vector[std.vector[llama_grammar_element]]
# Step 1: Convert each list to llama_grammar_element array and get pointer
self.element_arrays = [
(llama_grammar_element * len(sublist))(*sublist)
for sublist in grammar_rules
] # type: List[Array[llama_grammar_element]]
# Step 2: Get pointer of each array
self.element_array_pointers = [
cast(subarray, llama_grammar_element_p)
for subarray in self.element_arrays
] # type: List[llama_grammar_element_p]
# Step 3: Make array of these pointers and get its pointer
self.rules = (
llama_grammar_element_p * len(self.element_array_pointers)
)(*self.element_array_pointers)
self.n_rules = c_size_t(grammar_rules.size())
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
self._grammar = llama_cpp.llama_grammar_init(
self.rules, self.n_rules, self.start_rule_index
)
) # type: std.vector[std.vector[LlamaGrammarElement]]
self._n_rules = self._grammar_rules.size() # type: int
self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int
self.grammar = self.init()
@classmethod
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
@ -110,23 +89,45 @@ class LlamaGrammar:
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
)
@property
def grammar(self) -> llama_grammar_p:
if self._grammar is None:
raise ValueError(
f"{self.__class__.__name__}.grammar: grammar is freed"
)
return self._grammar
def init(self) -> None:
# Step 1: Convert LlamaGrammarElement to llama_grammar_element
self._element_lists = [
[
llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value))
for elem in subvector
]
for subvector in self._grammar_rules
] # type: List[List[llama_grammar_element]]
@grammar.setter
def grammar(self, value: Optional[llama_grammar_p]) -> None:
self._grammar = value
# Step 2: Convert each list to llama_grammar_element array and get pointer
self._element_arrays = [
(llama_grammar_element * len(sublist))(*sublist)
for sublist in self._element_lists
] # type: List[Array[llama_grammar_element]]
# Step 3: Get pointer of each array
self._element_array_pointers = [
cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays
] # type: List[llama_grammar_element_p]
# Step 4: Make array of these pointers and get its pointer
self._rules = (llama_grammar_element_p * len(self._element_array_pointers))(
*self._element_array_pointers
)
self.grammar = llama_cpp.llama_grammar_init(
self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index)
)
def reset(self) -> None:
if self.grammar is not None:
llama_cpp.llama_grammar_free(self.grammar)
self.grammar = llama_cpp.llama_grammar_init(
self.rules, self.n_rules, self.start_rule_index
)
self.init()
class LlamaGrammarElement:
def __init__(self, type: "llama_gretype", value: int):
self.type = type
self.value = value # Unicode code point or rule ID
class const_char_p:
@ -182,21 +183,15 @@ class const_char_p:
)
def __eq__(self: Ptr, other: Ptr) -> bool:
assert (
self.value == other.value
), "comparing pointers from different strings"
assert self.value == other.value, "comparing pointers from different strings"
return self.pos == other.pos
def __lt__(self: Ptr, other: Ptr) -> bool:
assert (
self.value == other.value
), "comparing pointers from different strings"
assert self.value == other.value, "comparing pointers from different strings"
return self.pos < other.pos
def __gt__(self: Ptr, other: Ptr) -> bool:
assert (
self.value == other.value
), "comparing pointers from different strings"
assert self.value == other.value, "comparing pointers from different strings"
return self.pos > other.pos
@ -220,9 +215,7 @@ class std:
def _check_version(self):
if self._version != self._vector._version:
raise RuntimeError(
"Iterator used after vector was modified."
)
raise RuntimeError("Iterator used after vector was modified.")
def __iter__(self):
return self
@ -280,16 +273,12 @@ class std:
) -> None:
if new_size > self.size():
if fill_value_factory is None:
raise ValueError(
"A fill value factory function must be provided."
)
raise ValueError("A fill value factory function must be provided.")
self.reserve(new_size, fill_value_factory)
elif new_size < self.size():
self[:] = self[:new_size]
def reserve(
self, capacity: int, fill_value_factory: Callable[[], T]
) -> None:
def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None:
if capacity > self.size():
fill_value = fill_value_factory()
self.extend([fill_value] * (capacity - self.size()))
@ -401,9 +390,7 @@ class std:
for k in keys:
if k >= key:
return self.iterator(self, k)
raise ValueError(
"No key found that is not less than the input key"
)
raise ValueError("No key found that is not less than the input key")
except TypeError:
raise TypeError("Keys of type T cannot be sorted.")
@ -460,9 +447,7 @@ class llama_gretype(Enum):
class parse_state:
def __init__(self):
self.symbol_ids: std.map[str, int] = std.map()
self.rules: std.vector[
std.vector[llama_grammar_element]
] = std.vector()
self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector()
# std::vector<const llama_grammar_element *> parse_state::c_rules() {
# std::vector<const llama_grammar_element *> ret;
@ -471,16 +456,16 @@ class parse_state:
# }
# return ret;
# }
def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]:
ret = (
std.vector()
) # type: std.vector[std.vector[llama_grammar_element]]
def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]:
ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]]
for rule in self.rules:
ret.push_back(rule.data())
return ret
def __repr__(self) -> str:
return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
return (
f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
)
# struct llama_grammar {
@ -531,12 +516,12 @@ def generate_symbol_id(state: parse_state, base_name: str) -> int:
def add_rule(
state: parse_state,
rule_id: int,
rule: std.vector[llama_grammar_element],
rule: std.vector[LlamaGrammarElement],
) -> None:
if state.rules.size() <= rule_id:
state.rules.resize(
rule_id + 1,
fill_value_factory=std.vector[llama_grammar_element],
fill_value_factory=std.vector[LlamaGrammarElement],
)
state.rules[rule_id] = rule
@ -575,9 +560,7 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]:
# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
# }
def is_word_char(c: str) -> bool:
return (
("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
)
return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
@ -619,9 +602,7 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]:
break
pos += 1
if pos != end:
raise RuntimeError(
"expecting " + str(size) + " hex chars at " + str(src)
)
raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src))
return (value, pos)
@ -707,9 +688,7 @@ def parse_name(src: const_char_p) -> const_char_p:
# }
def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p:
pos = const_char_p(src) # type: const_char_p
while pos[0] in (" ", "\t", "#") or (
newline_ok and pos[0] in ("\r", "\n")
):
while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")):
if pos[0] == "#":
while pos[0] is not None and pos[0] not in ("\r", "\n"):
pos += 1
@ -728,7 +707,7 @@ def parse_sequence(
state: parse_state,
src: const_char_p,
rule_name: str,
out_elements: std.vector[llama_grammar_element],
out_elements: std.vector[LlamaGrammarElement],
is_nested: bool,
) -> const_char_p:
# size_t last_sym_start = out_elements.size();
@ -753,9 +732,7 @@ def parse_sequence(
char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
pos = char_pair[1]
out_elements.push_back(
llama_grammar_element(
llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0]
)
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0])
)
pos = parse_space(pos + 1, is_nested)
# } else if (*pos == '[') { // char range(s)
@ -763,9 +740,7 @@ def parse_sequence(
# enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
elif pos[0] == "[": # char range(s)
pos += 1
start_type = (
llama_gretype.LLAMA_GRETYPE_CHAR
) # type: llama_gretype
start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype
# if (*pos == '^') {
# pos++;
# start_type = LLAMA_GRETYPE_CHAR_NOT;
@ -790,9 +765,7 @@ def parse_sequence(
if last_sym_start < out_elements.size()
else start_type
) # type: llama_gretype
out_elements.push_back(
llama_grammar_element(type.value, char_pair[0])
)
out_elements.push_back(LlamaGrammarElement(type, char_pair[0]))
# if (pos[0] == '-' && pos[1] != ']') {
# auto endchar_pair = parse_char(pos + 1);
# pos = endchar_pair.second;
@ -800,13 +773,11 @@ def parse_sequence(
# }
# }
if pos[0] == "-" and pos[1] != "]":
endchar_pair = parse_char(
pos + 1
) # type: Tuple[int, const_char_p]
endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p]
pos = endchar_pair[1]
out_elements.push_back(
llama_grammar_element(
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
LlamaGrammarElement(
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
endchar_pair[0],
)
)
@ -820,15 +791,11 @@ def parse_sequence(
# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
elif is_word_char(pos[0]): # rule reference
name_end = parse_name(pos) # type: const_char_p
ref_rule_id = get_symbol_id(
state, pos, name_end - pos
) # type: int
ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int
pos = parse_space(name_end, is_nested)
last_sym_start = out_elements.size()
out_elements.push_back(
llama_grammar_element(
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id
)
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id)
)
# } else if (*pos == '(') { // grouping
# // parse nested alternates into synthesized rule
@ -850,9 +817,7 @@ def parse_sequence(
last_sym_start = out_elements.size()
# output reference to synthesized rule
out_elements.push_back(
llama_grammar_element(
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
)
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
)
if pos[0] != ")":
raise RuntimeError("expecting ')' at " + str(pos))
@ -863,9 +828,7 @@ def parse_sequence(
# }
elif pos[0] in ("*", "+", "?"): # repetition operator
if last_sym_start == out_elements.size():
raise RuntimeError(
"expecting preceding item to */+/? at " + str(pos)
)
raise RuntimeError("expecting preceding item to */+/? at " + str(pos))
# // apply transformation to previous symbol (last_sym_start to end) according to
# // rewrite rules:
# // S* --> S' ::= S S' |
@ -878,8 +841,8 @@ def parse_sequence(
# sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
sub_rule_id = generate_symbol_id(state, rule_name) # type: int
sub_rule = std.vector[
llama_grammar_element
]() # type: std.vector[llama_grammar_element]
LlamaGrammarElement
]() # type: std.vector[LlamaGrammarElement]
sub_rule.insert(
sub_rule.end(),
out_elements.begin() + last_sym_start,
@ -893,13 +856,11 @@ def parse_sequence(
# sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
if pos[0] in ("*", "+"):
sub_rule.push_back(
llama_grammar_element(
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
LlamaGrammarElement(
llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id
)
)
sub_rule.push_back(
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0)
)
sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
# if (*pos == '+') {
# // add preceding symbol as alternate only for '+' (otherwise empty)
# sub_rule.insert(
@ -918,16 +879,12 @@ def parse_sequence(
out_elements.begin() + last_sym_start,
out_elements.end(),
)
sub_rule.push_back(
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0)
)
sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
add_rule(state, sub_rule_id, sub_rule)
# in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start)
out_elements.push_back(
llama_grammar_element(
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
)
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
)
pos = parse_space(pos + 1, is_nested)
# } else {
@ -965,19 +922,13 @@ def parse_alternates(
rule_id: int,
is_nested: bool,
) -> const_char_p:
rule = std.vector() # type: std.vector[llama_grammar_element]
pos = parse_sequence(
state, src, rule_name, rule, is_nested
) # type: const_char_p
rule = std.vector() # type: std.vector[LlamaGrammarElement]
pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p
while pos[0] == "|":
rule.push_back(
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0)
)
rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
pos = parse_space(pos + 1, True)
pos = parse_sequence(state, pos, rule_name, rule, is_nested)
rule.push_back(
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0)
)
rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
add_rule(state, rule_id, rule)
return pos
@ -1017,9 +968,7 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
raise RuntimeError("expecting ::= at " + str(pos))
pos = parse_space(pos + 3, True) # type: const_char_p
pos = parse_alternates(
state, pos, name, rule_id, False
) # type: const_char_p
pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p
if pos[0] == "\r":
pos += 2 if pos[1] == "\n" else 1
@ -1080,7 +1029,7 @@ def print_grammar_char(file: TextIO, c: int) -> None:
# default: return false;
# }
# }
def is_char_element(elem: llama_grammar_element) -> bool:
def is_char_element(elem: LlamaGrammarElement) -> bool:
return elem.type in (
llama_gretype.LLAMA_GRETYPE_CHAR.value,
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value,
@ -1097,7 +1046,7 @@ def is_char_element(elem: llama_grammar_element) -> bool:
def print_rule(
file: TextIO,
rule_id: int,
rule: std.vector[llama_grammar_element],
rule: std.vector[LlamaGrammarElement],
symbol_id_names: std.map[int, str],
) -> None:
# if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
@ -1105,13 +1054,9 @@ def print_rule(
# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
# }
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
if (
rule.empty()
or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value
):
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value:
raise RuntimeError(
"malformed rule, does not end with LLAMA_GRETYPE_END: "
+ str(rule_id)
"malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
)
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
@ -1154,22 +1099,20 @@ def print_rule(
# break;
# }
for i, elem in enumerate(rule[:-1]):
case = elem.type # type: int
if case == llama_gretype.LLAMA_GRETYPE_END.value:
raise RuntimeError(
"unexpected end of rule: " + str(rule_id) + "," + str(i)
)
elif case == llama_gretype.LLAMA_GRETYPE_ALT.value:
case = elem.type # type: llama_gretype
if case is llama_gretype.LLAMA_GRETYPE_END.value:
raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
elif case is llama_gretype.LLAMA_GRETYPE_ALT:
print("| ", file=file, end="")
elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value:
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
print(f"{symbol_id_names.at(elem.value)} ", file=file, end="")
elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value:
elif case is llama_gretype.LLAMA_GRETYPE_CHAR:
print("[", file=file, end="")
print_grammar_char(file, elem.value)
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value:
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT:
print("[^", file=file, end="")
print_grammar_char(file, elem.value)
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value:
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER:
if i == 0 or not is_char_element(rule[i - 1]):
raise RuntimeError(
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: "
@ -1179,7 +1122,7 @@ def print_rule(
)
print("-", file=file, end="")
print_grammar_char(file, elem.value)
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value:
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT:
if i == 0 or not is_char_element(rule[i - 1]):
raise RuntimeError(
"LLAMA_GRETYPE_CHAR_ALT without preceding char: "