prevent memory access error by llama_grammar_free
This commit is contained in:
parent
b07713cb9f
commit
0d7d2031a9
1 changed files with 98 additions and 155 deletions
|
@ -52,33 +52,12 @@ class LlamaGrammar:
|
||||||
parsed_grammar: "parse_state",
|
parsed_grammar: "parse_state",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the grammar pointer from the parsed state."""
|
"""Initialize the grammar pointer from the parsed state."""
|
||||||
self.parsed_grammar = parsed_grammar
|
self._grammar_rules = (
|
||||||
grammar_rules = (
|
|
||||||
parsed_grammar.c_rules()
|
parsed_grammar.c_rules()
|
||||||
) # type: std.vector[std.vector[llama_grammar_element]]
|
) # type: std.vector[std.vector[LlamaGrammarElement]]
|
||||||
|
self._n_rules = self._grammar_rules.size() # type: int
|
||||||
# Step 1: Convert each list to llama_grammar_element array and get pointer
|
self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int
|
||||||
self.element_arrays = [
|
self.grammar = self.init()
|
||||||
(llama_grammar_element * len(sublist))(*sublist)
|
|
||||||
for sublist in grammar_rules
|
|
||||||
] # type: List[Array[llama_grammar_element]]
|
|
||||||
|
|
||||||
# Step 2: Get pointer of each array
|
|
||||||
self.element_array_pointers = [
|
|
||||||
cast(subarray, llama_grammar_element_p)
|
|
||||||
for subarray in self.element_arrays
|
|
||||||
] # type: List[llama_grammar_element_p]
|
|
||||||
|
|
||||||
# Step 3: Make array of these pointers and get its pointer
|
|
||||||
self.rules = (
|
|
||||||
llama_grammar_element_p * len(self.element_array_pointers)
|
|
||||||
)(*self.element_array_pointers)
|
|
||||||
|
|
||||||
self.n_rules = c_size_t(grammar_rules.size())
|
|
||||||
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
|
|
||||||
self._grammar = llama_cpp.llama_grammar_init(
|
|
||||||
self.rules, self.n_rules, self.start_rule_index
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
|
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
|
||||||
|
@ -110,23 +89,45 @@ class LlamaGrammar:
|
||||||
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
|
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
def init(self) -> None:
|
||||||
def grammar(self) -> llama_grammar_p:
|
# Step 1: Convert LlamaGrammarElement to llama_grammar_element
|
||||||
if self._grammar is None:
|
self._element_lists = [
|
||||||
raise ValueError(
|
[
|
||||||
f"{self.__class__.__name__}.grammar: grammar is freed"
|
llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value))
|
||||||
)
|
for elem in subvector
|
||||||
return self._grammar
|
]
|
||||||
|
for subvector in self._grammar_rules
|
||||||
|
] # type: List[List[llama_grammar_element]]
|
||||||
|
|
||||||
@grammar.setter
|
# Step 2: Convert each list to llama_grammar_element array and get pointer
|
||||||
def grammar(self, value: Optional[llama_grammar_p]) -> None:
|
self._element_arrays = [
|
||||||
self._grammar = value
|
(llama_grammar_element * len(sublist))(*sublist)
|
||||||
|
for sublist in self._element_lists
|
||||||
|
] # type: List[Array[llama_grammar_element]]
|
||||||
|
|
||||||
|
# Step 3: Get pointer of each array
|
||||||
|
self._element_array_pointers = [
|
||||||
|
cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays
|
||||||
|
] # type: List[llama_grammar_element_p]
|
||||||
|
|
||||||
|
# Step 4: Make array of these pointers and get its pointer
|
||||||
|
self._rules = (llama_grammar_element_p * len(self._element_array_pointers))(
|
||||||
|
*self._element_array_pointers
|
||||||
|
)
|
||||||
|
self.grammar = llama_cpp.llama_grammar_init(
|
||||||
|
self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index)
|
||||||
|
)
|
||||||
|
|
||||||
def reset(self) -> None:
|
def reset(self) -> None:
|
||||||
llama_cpp.llama_grammar_free(self.grammar)
|
if self.grammar is not None:
|
||||||
self.grammar = llama_cpp.llama_grammar_init(
|
llama_cpp.llama_grammar_free(self.grammar)
|
||||||
self.rules, self.n_rules, self.start_rule_index
|
self.init()
|
||||||
)
|
|
||||||
|
|
||||||
|
class LlamaGrammarElement:
|
||||||
|
def __init__(self, type: "llama_gretype", value: int):
|
||||||
|
self.type = type
|
||||||
|
self.value = value # Unicode code point or rule ID
|
||||||
|
|
||||||
|
|
||||||
class const_char_p:
|
class const_char_p:
|
||||||
|
@ -182,21 +183,15 @@ class const_char_p:
|
||||||
)
|
)
|
||||||
|
|
||||||
def __eq__(self: Ptr, other: Ptr) -> bool:
|
def __eq__(self: Ptr, other: Ptr) -> bool:
|
||||||
assert (
|
assert self.value == other.value, "comparing pointers from different strings"
|
||||||
self.value == other.value
|
|
||||||
), "comparing pointers from different strings"
|
|
||||||
return self.pos == other.pos
|
return self.pos == other.pos
|
||||||
|
|
||||||
def __lt__(self: Ptr, other: Ptr) -> bool:
|
def __lt__(self: Ptr, other: Ptr) -> bool:
|
||||||
assert (
|
assert self.value == other.value, "comparing pointers from different strings"
|
||||||
self.value == other.value
|
|
||||||
), "comparing pointers from different strings"
|
|
||||||
return self.pos < other.pos
|
return self.pos < other.pos
|
||||||
|
|
||||||
def __gt__(self: Ptr, other: Ptr) -> bool:
|
def __gt__(self: Ptr, other: Ptr) -> bool:
|
||||||
assert (
|
assert self.value == other.value, "comparing pointers from different strings"
|
||||||
self.value == other.value
|
|
||||||
), "comparing pointers from different strings"
|
|
||||||
return self.pos > other.pos
|
return self.pos > other.pos
|
||||||
|
|
||||||
|
|
||||||
|
@ -220,9 +215,7 @@ class std:
|
||||||
|
|
||||||
def _check_version(self):
|
def _check_version(self):
|
||||||
if self._version != self._vector._version:
|
if self._version != self._vector._version:
|
||||||
raise RuntimeError(
|
raise RuntimeError("Iterator used after vector was modified.")
|
||||||
"Iterator used after vector was modified."
|
|
||||||
)
|
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return self
|
return self
|
||||||
|
@ -280,16 +273,12 @@ class std:
|
||||||
) -> None:
|
) -> None:
|
||||||
if new_size > self.size():
|
if new_size > self.size():
|
||||||
if fill_value_factory is None:
|
if fill_value_factory is None:
|
||||||
raise ValueError(
|
raise ValueError("A fill value factory function must be provided.")
|
||||||
"A fill value factory function must be provided."
|
|
||||||
)
|
|
||||||
self.reserve(new_size, fill_value_factory)
|
self.reserve(new_size, fill_value_factory)
|
||||||
elif new_size < self.size():
|
elif new_size < self.size():
|
||||||
self[:] = self[:new_size]
|
self[:] = self[:new_size]
|
||||||
|
|
||||||
def reserve(
|
def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None:
|
||||||
self, capacity: int, fill_value_factory: Callable[[], T]
|
|
||||||
) -> None:
|
|
||||||
if capacity > self.size():
|
if capacity > self.size():
|
||||||
fill_value = fill_value_factory()
|
fill_value = fill_value_factory()
|
||||||
self.extend([fill_value] * (capacity - self.size()))
|
self.extend([fill_value] * (capacity - self.size()))
|
||||||
|
@ -401,9 +390,7 @@ class std:
|
||||||
for k in keys:
|
for k in keys:
|
||||||
if k >= key:
|
if k >= key:
|
||||||
return self.iterator(self, k)
|
return self.iterator(self, k)
|
||||||
raise ValueError(
|
raise ValueError("No key found that is not less than the input key")
|
||||||
"No key found that is not less than the input key"
|
|
||||||
)
|
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise TypeError("Keys of type T cannot be sorted.")
|
raise TypeError("Keys of type T cannot be sorted.")
|
||||||
|
|
||||||
|
@ -460,9 +447,7 @@ class llama_gretype(Enum):
|
||||||
class parse_state:
|
class parse_state:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.symbol_ids: std.map[str, int] = std.map()
|
self.symbol_ids: std.map[str, int] = std.map()
|
||||||
self.rules: std.vector[
|
self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector()
|
||||||
std.vector[llama_grammar_element]
|
|
||||||
] = std.vector()
|
|
||||||
|
|
||||||
# std::vector<const llama_grammar_element *> parse_state::c_rules() {
|
# std::vector<const llama_grammar_element *> parse_state::c_rules() {
|
||||||
# std::vector<const llama_grammar_element *> ret;
|
# std::vector<const llama_grammar_element *> ret;
|
||||||
|
@ -471,16 +456,16 @@ class parse_state:
|
||||||
# }
|
# }
|
||||||
# return ret;
|
# return ret;
|
||||||
# }
|
# }
|
||||||
def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]:
|
def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]:
|
||||||
ret = (
|
ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]]
|
||||||
std.vector()
|
|
||||||
) # type: std.vector[std.vector[llama_grammar_element]]
|
|
||||||
for rule in self.rules:
|
for rule in self.rules:
|
||||||
ret.push_back(rule.data())
|
ret.push_back(rule.data())
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
|
return (
|
||||||
|
f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# struct llama_grammar {
|
# struct llama_grammar {
|
||||||
|
@ -531,12 +516,12 @@ def generate_symbol_id(state: parse_state, base_name: str) -> int:
|
||||||
def add_rule(
|
def add_rule(
|
||||||
state: parse_state,
|
state: parse_state,
|
||||||
rule_id: int,
|
rule_id: int,
|
||||||
rule: std.vector[llama_grammar_element],
|
rule: std.vector[LlamaGrammarElement],
|
||||||
) -> None:
|
) -> None:
|
||||||
if state.rules.size() <= rule_id:
|
if state.rules.size() <= rule_id:
|
||||||
state.rules.resize(
|
state.rules.resize(
|
||||||
rule_id + 1,
|
rule_id + 1,
|
||||||
fill_value_factory=std.vector[llama_grammar_element],
|
fill_value_factory=std.vector[LlamaGrammarElement],
|
||||||
)
|
)
|
||||||
state.rules[rule_id] = rule
|
state.rules[rule_id] = rule
|
||||||
|
|
||||||
|
@ -575,9 +560,7 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]:
|
||||||
# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
||||||
# }
|
# }
|
||||||
def is_word_char(c: str) -> bool:
|
def is_word_char(c: str) -> bool:
|
||||||
return (
|
return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
|
||||||
("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||||
|
@ -619,9 +602,7 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]:
|
||||||
break
|
break
|
||||||
pos += 1
|
pos += 1
|
||||||
if pos != end:
|
if pos != end:
|
||||||
raise RuntimeError(
|
raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src))
|
||||||
"expecting " + str(size) + " hex chars at " + str(src)
|
|
||||||
)
|
|
||||||
return (value, pos)
|
return (value, pos)
|
||||||
|
|
||||||
|
|
||||||
|
@ -707,9 +688,7 @@ def parse_name(src: const_char_p) -> const_char_p:
|
||||||
# }
|
# }
|
||||||
def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p:
|
def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p:
|
||||||
pos = const_char_p(src) # type: const_char_p
|
pos = const_char_p(src) # type: const_char_p
|
||||||
while pos[0] in (" ", "\t", "#") or (
|
while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")):
|
||||||
newline_ok and pos[0] in ("\r", "\n")
|
|
||||||
):
|
|
||||||
if pos[0] == "#":
|
if pos[0] == "#":
|
||||||
while pos[0] is not None and pos[0] not in ("\r", "\n"):
|
while pos[0] is not None and pos[0] not in ("\r", "\n"):
|
||||||
pos += 1
|
pos += 1
|
||||||
|
@ -728,7 +707,7 @@ def parse_sequence(
|
||||||
state: parse_state,
|
state: parse_state,
|
||||||
src: const_char_p,
|
src: const_char_p,
|
||||||
rule_name: str,
|
rule_name: str,
|
||||||
out_elements: std.vector[llama_grammar_element],
|
out_elements: std.vector[LlamaGrammarElement],
|
||||||
is_nested: bool,
|
is_nested: bool,
|
||||||
) -> const_char_p:
|
) -> const_char_p:
|
||||||
# size_t last_sym_start = out_elements.size();
|
# size_t last_sym_start = out_elements.size();
|
||||||
|
@ -753,9 +732,7 @@ def parse_sequence(
|
||||||
char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
|
char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
|
||||||
pos = char_pair[1]
|
pos = char_pair[1]
|
||||||
out_elements.push_back(
|
out_elements.push_back(
|
||||||
llama_grammar_element(
|
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0])
|
||||||
llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
pos = parse_space(pos + 1, is_nested)
|
pos = parse_space(pos + 1, is_nested)
|
||||||
# } else if (*pos == '[') { // char range(s)
|
# } else if (*pos == '[') { // char range(s)
|
||||||
|
@ -763,9 +740,7 @@ def parse_sequence(
|
||||||
# enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
# enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
||||||
elif pos[0] == "[": # char range(s)
|
elif pos[0] == "[": # char range(s)
|
||||||
pos += 1
|
pos += 1
|
||||||
start_type = (
|
start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype
|
||||||
llama_gretype.LLAMA_GRETYPE_CHAR
|
|
||||||
) # type: llama_gretype
|
|
||||||
# if (*pos == '^') {
|
# if (*pos == '^') {
|
||||||
# pos++;
|
# pos++;
|
||||||
# start_type = LLAMA_GRETYPE_CHAR_NOT;
|
# start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||||
|
@ -790,9 +765,7 @@ def parse_sequence(
|
||||||
if last_sym_start < out_elements.size()
|
if last_sym_start < out_elements.size()
|
||||||
else start_type
|
else start_type
|
||||||
) # type: llama_gretype
|
) # type: llama_gretype
|
||||||
out_elements.push_back(
|
out_elements.push_back(LlamaGrammarElement(type, char_pair[0]))
|
||||||
llama_grammar_element(type.value, char_pair[0])
|
|
||||||
)
|
|
||||||
# if (pos[0] == '-' && pos[1] != ']') {
|
# if (pos[0] == '-' && pos[1] != ']') {
|
||||||
# auto endchar_pair = parse_char(pos + 1);
|
# auto endchar_pair = parse_char(pos + 1);
|
||||||
# pos = endchar_pair.second;
|
# pos = endchar_pair.second;
|
||||||
|
@ -800,13 +773,11 @@ def parse_sequence(
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
if pos[0] == "-" and pos[1] != "]":
|
if pos[0] == "-" and pos[1] != "]":
|
||||||
endchar_pair = parse_char(
|
endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p]
|
||||||
pos + 1
|
|
||||||
) # type: Tuple[int, const_char_p]
|
|
||||||
pos = endchar_pair[1]
|
pos = endchar_pair[1]
|
||||||
out_elements.push_back(
|
out_elements.push_back(
|
||||||
llama_grammar_element(
|
LlamaGrammarElement(
|
||||||
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
|
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
|
||||||
endchar_pair[0],
|
endchar_pair[0],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
@ -820,15 +791,11 @@ def parse_sequence(
|
||||||
# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||||
elif is_word_char(pos[0]): # rule reference
|
elif is_word_char(pos[0]): # rule reference
|
||||||
name_end = parse_name(pos) # type: const_char_p
|
name_end = parse_name(pos) # type: const_char_p
|
||||||
ref_rule_id = get_symbol_id(
|
ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int
|
||||||
state, pos, name_end - pos
|
|
||||||
) # type: int
|
|
||||||
pos = parse_space(name_end, is_nested)
|
pos = parse_space(name_end, is_nested)
|
||||||
last_sym_start = out_elements.size()
|
last_sym_start = out_elements.size()
|
||||||
out_elements.push_back(
|
out_elements.push_back(
|
||||||
llama_grammar_element(
|
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id)
|
||||||
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
# } else if (*pos == '(') { // grouping
|
# } else if (*pos == '(') { // grouping
|
||||||
# // parse nested alternates into synthesized rule
|
# // parse nested alternates into synthesized rule
|
||||||
|
@ -850,9 +817,7 @@ def parse_sequence(
|
||||||
last_sym_start = out_elements.size()
|
last_sym_start = out_elements.size()
|
||||||
# output reference to synthesized rule
|
# output reference to synthesized rule
|
||||||
out_elements.push_back(
|
out_elements.push_back(
|
||||||
llama_grammar_element(
|
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
|
||||||
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
if pos[0] != ")":
|
if pos[0] != ")":
|
||||||
raise RuntimeError("expecting ')' at " + str(pos))
|
raise RuntimeError("expecting ')' at " + str(pos))
|
||||||
|
@ -863,9 +828,7 @@ def parse_sequence(
|
||||||
# }
|
# }
|
||||||
elif pos[0] in ("*", "+", "?"): # repetition operator
|
elif pos[0] in ("*", "+", "?"): # repetition operator
|
||||||
if last_sym_start == out_elements.size():
|
if last_sym_start == out_elements.size():
|
||||||
raise RuntimeError(
|
raise RuntimeError("expecting preceding item to */+/? at " + str(pos))
|
||||||
"expecting preceding item to */+/? at " + str(pos)
|
|
||||||
)
|
|
||||||
# // apply transformation to previous symbol (last_sym_start to end) according to
|
# // apply transformation to previous symbol (last_sym_start to end) according to
|
||||||
# // rewrite rules:
|
# // rewrite rules:
|
||||||
# // S* --> S' ::= S S' |
|
# // S* --> S' ::= S S' |
|
||||||
|
@ -878,8 +841,8 @@ def parse_sequence(
|
||||||
# sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
# sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||||
sub_rule_id = generate_symbol_id(state, rule_name) # type: int
|
sub_rule_id = generate_symbol_id(state, rule_name) # type: int
|
||||||
sub_rule = std.vector[
|
sub_rule = std.vector[
|
||||||
llama_grammar_element
|
LlamaGrammarElement
|
||||||
]() # type: std.vector[llama_grammar_element]
|
]() # type: std.vector[LlamaGrammarElement]
|
||||||
sub_rule.insert(
|
sub_rule.insert(
|
||||||
sub_rule.end(),
|
sub_rule.end(),
|
||||||
out_elements.begin() + last_sym_start,
|
out_elements.begin() + last_sym_start,
|
||||||
|
@ -893,13 +856,11 @@ def parse_sequence(
|
||||||
# sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
# sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||||
if pos[0] in ("*", "+"):
|
if pos[0] in ("*", "+"):
|
||||||
sub_rule.push_back(
|
sub_rule.push_back(
|
||||||
llama_grammar_element(
|
LlamaGrammarElement(
|
||||||
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
|
llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
sub_rule.push_back(
|
sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
|
||||||
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0)
|
|
||||||
)
|
|
||||||
# if (*pos == '+') {
|
# if (*pos == '+') {
|
||||||
# // add preceding symbol as alternate only for '+' (otherwise empty)
|
# // add preceding symbol as alternate only for '+' (otherwise empty)
|
||||||
# sub_rule.insert(
|
# sub_rule.insert(
|
||||||
|
@ -918,16 +879,12 @@ def parse_sequence(
|
||||||
out_elements.begin() + last_sym_start,
|
out_elements.begin() + last_sym_start,
|
||||||
out_elements.end(),
|
out_elements.end(),
|
||||||
)
|
)
|
||||||
sub_rule.push_back(
|
sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
|
||||||
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0)
|
|
||||||
)
|
|
||||||
add_rule(state, sub_rule_id, sub_rule)
|
add_rule(state, sub_rule_id, sub_rule)
|
||||||
# in original rule, replace previous symbol with reference to generated rule
|
# in original rule, replace previous symbol with reference to generated rule
|
||||||
out_elements.resize(last_sym_start)
|
out_elements.resize(last_sym_start)
|
||||||
out_elements.push_back(
|
out_elements.push_back(
|
||||||
llama_grammar_element(
|
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
|
||||||
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
pos = parse_space(pos + 1, is_nested)
|
pos = parse_space(pos + 1, is_nested)
|
||||||
# } else {
|
# } else {
|
||||||
|
@ -965,19 +922,13 @@ def parse_alternates(
|
||||||
rule_id: int,
|
rule_id: int,
|
||||||
is_nested: bool,
|
is_nested: bool,
|
||||||
) -> const_char_p:
|
) -> const_char_p:
|
||||||
rule = std.vector() # type: std.vector[llama_grammar_element]
|
rule = std.vector() # type: std.vector[LlamaGrammarElement]
|
||||||
pos = parse_sequence(
|
pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p
|
||||||
state, src, rule_name, rule, is_nested
|
|
||||||
) # type: const_char_p
|
|
||||||
while pos[0] == "|":
|
while pos[0] == "|":
|
||||||
rule.push_back(
|
rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
|
||||||
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0)
|
|
||||||
)
|
|
||||||
pos = parse_space(pos + 1, True)
|
pos = parse_space(pos + 1, True)
|
||||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested)
|
pos = parse_sequence(state, pos, rule_name, rule, is_nested)
|
||||||
rule.push_back(
|
rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
|
||||||
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0)
|
|
||||||
)
|
|
||||||
add_rule(state, rule_id, rule)
|
add_rule(state, rule_id, rule)
|
||||||
return pos
|
return pos
|
||||||
|
|
||||||
|
@ -1017,9 +968,7 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
|
||||||
raise RuntimeError("expecting ::= at " + str(pos))
|
raise RuntimeError("expecting ::= at " + str(pos))
|
||||||
|
|
||||||
pos = parse_space(pos + 3, True) # type: const_char_p
|
pos = parse_space(pos + 3, True) # type: const_char_p
|
||||||
pos = parse_alternates(
|
pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p
|
||||||
state, pos, name, rule_id, False
|
|
||||||
) # type: const_char_p
|
|
||||||
|
|
||||||
if pos[0] == "\r":
|
if pos[0] == "\r":
|
||||||
pos += 2 if pos[1] == "\n" else 1
|
pos += 2 if pos[1] == "\n" else 1
|
||||||
|
@ -1080,7 +1029,7 @@ def print_grammar_char(file: TextIO, c: int) -> None:
|
||||||
# default: return false;
|
# default: return false;
|
||||||
# }
|
# }
|
||||||
# }
|
# }
|
||||||
def is_char_element(elem: llama_grammar_element) -> bool:
|
def is_char_element(elem: LlamaGrammarElement) -> bool:
|
||||||
return elem.type in (
|
return elem.type in (
|
||||||
llama_gretype.LLAMA_GRETYPE_CHAR.value,
|
llama_gretype.LLAMA_GRETYPE_CHAR.value,
|
||||||
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value,
|
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value,
|
||||||
|
@ -1097,7 +1046,7 @@ def is_char_element(elem: llama_grammar_element) -> bool:
|
||||||
def print_rule(
|
def print_rule(
|
||||||
file: TextIO,
|
file: TextIO,
|
||||||
rule_id: int,
|
rule_id: int,
|
||||||
rule: std.vector[llama_grammar_element],
|
rule: std.vector[LlamaGrammarElement],
|
||||||
symbol_id_names: std.map[int, str],
|
symbol_id_names: std.map[int, str],
|
||||||
) -> None:
|
) -> None:
|
||||||
# if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
|
# if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
|
||||||
|
@ -1105,13 +1054,9 @@ def print_rule(
|
||||||
# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
||||||
# }
|
# }
|
||||||
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||||
if (
|
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value:
|
||||||
rule.empty()
|
|
||||||
or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value
|
|
||||||
):
|
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"malformed rule, does not end with LLAMA_GRETYPE_END: "
|
"malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
|
||||||
+ str(rule_id)
|
|
||||||
)
|
)
|
||||||
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
|
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
|
||||||
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||||
|
@ -1154,22 +1099,20 @@ def print_rule(
|
||||||
# break;
|
# break;
|
||||||
# }
|
# }
|
||||||
for i, elem in enumerate(rule[:-1]):
|
for i, elem in enumerate(rule[:-1]):
|
||||||
case = elem.type # type: int
|
case = elem.type # type: llama_gretype
|
||||||
if case == llama_gretype.LLAMA_GRETYPE_END.value:
|
if case is llama_gretype.LLAMA_GRETYPE_END.value:
|
||||||
raise RuntimeError(
|
raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
|
||||||
"unexpected end of rule: " + str(rule_id) + "," + str(i)
|
elif case is llama_gretype.LLAMA_GRETYPE_ALT:
|
||||||
)
|
|
||||||
elif case == llama_gretype.LLAMA_GRETYPE_ALT.value:
|
|
||||||
print("| ", file=file, end="")
|
print("| ", file=file, end="")
|
||||||
elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value:
|
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
|
||||||
print(f"{symbol_id_names.at(elem.value)} ", file=file, end="")
|
print(f"{symbol_id_names.at(elem.value)} ", file=file, end="")
|
||||||
elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value:
|
elif case is llama_gretype.LLAMA_GRETYPE_CHAR:
|
||||||
print("[", file=file, end="")
|
print("[", file=file, end="")
|
||||||
print_grammar_char(file, elem.value)
|
print_grammar_char(file, elem.value)
|
||||||
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value:
|
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT:
|
||||||
print("[^", file=file, end="")
|
print("[^", file=file, end="")
|
||||||
print_grammar_char(file, elem.value)
|
print_grammar_char(file, elem.value)
|
||||||
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value:
|
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||||
if i == 0 or not is_char_element(rule[i - 1]):
|
if i == 0 or not is_char_element(rule[i - 1]):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: "
|
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: "
|
||||||
|
@ -1179,7 +1122,7 @@ def print_rule(
|
||||||
)
|
)
|
||||||
print("-", file=file, end="")
|
print("-", file=file, end="")
|
||||||
print_grammar_char(file, elem.value)
|
print_grammar_char(file, elem.value)
|
||||||
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value:
|
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT:
|
||||||
if i == 0 or not is_char_element(rule[i - 1]):
|
if i == 0 or not is_char_element(rule[i - 1]):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"LLAMA_GRETYPE_CHAR_ALT without preceding char: "
|
"LLAMA_GRETYPE_CHAR_ALT without preceding char: "
|
||||||
|
|
Loading…
Reference in a new issue