prevent memory access error by llama_grammar_free
This commit is contained in:
parent
b07713cb9f
commit
0d7d2031a9
1 changed files with 98 additions and 155 deletions
|
@ -52,33 +52,12 @@ class LlamaGrammar:
|
|||
parsed_grammar: "parse_state",
|
||||
) -> None:
|
||||
"""Initialize the grammar pointer from the parsed state."""
|
||||
self.parsed_grammar = parsed_grammar
|
||||
grammar_rules = (
|
||||
self._grammar_rules = (
|
||||
parsed_grammar.c_rules()
|
||||
) # type: std.vector[std.vector[llama_grammar_element]]
|
||||
|
||||
# Step 1: Convert each list to llama_grammar_element array and get pointer
|
||||
self.element_arrays = [
|
||||
(llama_grammar_element * len(sublist))(*sublist)
|
||||
for sublist in grammar_rules
|
||||
] # type: List[Array[llama_grammar_element]]
|
||||
|
||||
# Step 2: Get pointer of each array
|
||||
self.element_array_pointers = [
|
||||
cast(subarray, llama_grammar_element_p)
|
||||
for subarray in self.element_arrays
|
||||
] # type: List[llama_grammar_element_p]
|
||||
|
||||
# Step 3: Make array of these pointers and get its pointer
|
||||
self.rules = (
|
||||
llama_grammar_element_p * len(self.element_array_pointers)
|
||||
)(*self.element_array_pointers)
|
||||
|
||||
self.n_rules = c_size_t(grammar_rules.size())
|
||||
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
|
||||
self._grammar = llama_cpp.llama_grammar_init(
|
||||
self.rules, self.n_rules, self.start_rule_index
|
||||
)
|
||||
) # type: std.vector[std.vector[LlamaGrammarElement]]
|
||||
self._n_rules = self._grammar_rules.size() # type: int
|
||||
self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int
|
||||
self.grammar = self.init()
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
|
||||
|
@ -110,23 +89,45 @@ class LlamaGrammar:
|
|||
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
|
||||
)
|
||||
|
||||
@property
|
||||
def grammar(self) -> llama_grammar_p:
|
||||
if self._grammar is None:
|
||||
raise ValueError(
|
||||
f"{self.__class__.__name__}.grammar: grammar is freed"
|
||||
)
|
||||
return self._grammar
|
||||
def init(self) -> None:
|
||||
# Step 1: Convert LlamaGrammarElement to llama_grammar_element
|
||||
self._element_lists = [
|
||||
[
|
||||
llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value))
|
||||
for elem in subvector
|
||||
]
|
||||
for subvector in self._grammar_rules
|
||||
] # type: List[List[llama_grammar_element]]
|
||||
|
||||
@grammar.setter
|
||||
def grammar(self, value: Optional[llama_grammar_p]) -> None:
|
||||
self._grammar = value
|
||||
# Step 2: Convert each list to llama_grammar_element array and get pointer
|
||||
self._element_arrays = [
|
||||
(llama_grammar_element * len(sublist))(*sublist)
|
||||
for sublist in self._element_lists
|
||||
] # type: List[Array[llama_grammar_element]]
|
||||
|
||||
# Step 3: Get pointer of each array
|
||||
self._element_array_pointers = [
|
||||
cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays
|
||||
] # type: List[llama_grammar_element_p]
|
||||
|
||||
# Step 4: Make array of these pointers and get its pointer
|
||||
self._rules = (llama_grammar_element_p * len(self._element_array_pointers))(
|
||||
*self._element_array_pointers
|
||||
)
|
||||
self.grammar = llama_cpp.llama_grammar_init(
|
||||
self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index)
|
||||
)
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.grammar is not None:
|
||||
llama_cpp.llama_grammar_free(self.grammar)
|
||||
self.grammar = llama_cpp.llama_grammar_init(
|
||||
self.rules, self.n_rules, self.start_rule_index
|
||||
)
|
||||
self.init()
|
||||
|
||||
|
||||
class LlamaGrammarElement:
|
||||
def __init__(self, type: "llama_gretype", value: int):
|
||||
self.type = type
|
||||
self.value = value # Unicode code point or rule ID
|
||||
|
||||
|
||||
class const_char_p:
|
||||
|
@ -182,21 +183,15 @@ class const_char_p:
|
|||
)
|
||||
|
||||
def __eq__(self: Ptr, other: Ptr) -> bool:
|
||||
assert (
|
||||
self.value == other.value
|
||||
), "comparing pointers from different strings"
|
||||
assert self.value == other.value, "comparing pointers from different strings"
|
||||
return self.pos == other.pos
|
||||
|
||||
def __lt__(self: Ptr, other: Ptr) -> bool:
|
||||
assert (
|
||||
self.value == other.value
|
||||
), "comparing pointers from different strings"
|
||||
assert self.value == other.value, "comparing pointers from different strings"
|
||||
return self.pos < other.pos
|
||||
|
||||
def __gt__(self: Ptr, other: Ptr) -> bool:
|
||||
assert (
|
||||
self.value == other.value
|
||||
), "comparing pointers from different strings"
|
||||
assert self.value == other.value, "comparing pointers from different strings"
|
||||
return self.pos > other.pos
|
||||
|
||||
|
||||
|
@ -220,9 +215,7 @@ class std:
|
|||
|
||||
def _check_version(self):
|
||||
if self._version != self._vector._version:
|
||||
raise RuntimeError(
|
||||
"Iterator used after vector was modified."
|
||||
)
|
||||
raise RuntimeError("Iterator used after vector was modified.")
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
@ -280,16 +273,12 @@ class std:
|
|||
) -> None:
|
||||
if new_size > self.size():
|
||||
if fill_value_factory is None:
|
||||
raise ValueError(
|
||||
"A fill value factory function must be provided."
|
||||
)
|
||||
raise ValueError("A fill value factory function must be provided.")
|
||||
self.reserve(new_size, fill_value_factory)
|
||||
elif new_size < self.size():
|
||||
self[:] = self[:new_size]
|
||||
|
||||
def reserve(
|
||||
self, capacity: int, fill_value_factory: Callable[[], T]
|
||||
) -> None:
|
||||
def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None:
|
||||
if capacity > self.size():
|
||||
fill_value = fill_value_factory()
|
||||
self.extend([fill_value] * (capacity - self.size()))
|
||||
|
@ -401,9 +390,7 @@ class std:
|
|||
for k in keys:
|
||||
if k >= key:
|
||||
return self.iterator(self, k)
|
||||
raise ValueError(
|
||||
"No key found that is not less than the input key"
|
||||
)
|
||||
raise ValueError("No key found that is not less than the input key")
|
||||
except TypeError:
|
||||
raise TypeError("Keys of type T cannot be sorted.")
|
||||
|
||||
|
@ -460,9 +447,7 @@ class llama_gretype(Enum):
|
|||
class parse_state:
|
||||
def __init__(self):
|
||||
self.symbol_ids: std.map[str, int] = std.map()
|
||||
self.rules: std.vector[
|
||||
std.vector[llama_grammar_element]
|
||||
] = std.vector()
|
||||
self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector()
|
||||
|
||||
# std::vector<const llama_grammar_element *> parse_state::c_rules() {
|
||||
# std::vector<const llama_grammar_element *> ret;
|
||||
|
@ -471,16 +456,16 @@ class parse_state:
|
|||
# }
|
||||
# return ret;
|
||||
# }
|
||||
def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]:
|
||||
ret = (
|
||||
std.vector()
|
||||
) # type: std.vector[std.vector[llama_grammar_element]]
|
||||
def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]:
|
||||
ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]]
|
||||
for rule in self.rules:
|
||||
ret.push_back(rule.data())
|
||||
return ret
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
|
||||
return (
|
||||
f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
|
||||
)
|
||||
|
||||
|
||||
# struct llama_grammar {
|
||||
|
@ -531,12 +516,12 @@ def generate_symbol_id(state: parse_state, base_name: str) -> int:
|
|||
def add_rule(
|
||||
state: parse_state,
|
||||
rule_id: int,
|
||||
rule: std.vector[llama_grammar_element],
|
||||
rule: std.vector[LlamaGrammarElement],
|
||||
) -> None:
|
||||
if state.rules.size() <= rule_id:
|
||||
state.rules.resize(
|
||||
rule_id + 1,
|
||||
fill_value_factory=std.vector[llama_grammar_element],
|
||||
fill_value_factory=std.vector[LlamaGrammarElement],
|
||||
)
|
||||
state.rules[rule_id] = rule
|
||||
|
||||
|
@ -575,9 +560,7 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]:
|
|||
# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
|
||||
# }
|
||||
def is_word_char(c: str) -> bool:
|
||||
return (
|
||||
("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
|
||||
)
|
||||
return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
|
||||
|
||||
|
||||
# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
|
||||
|
@ -619,9 +602,7 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]:
|
|||
break
|
||||
pos += 1
|
||||
if pos != end:
|
||||
raise RuntimeError(
|
||||
"expecting " + str(size) + " hex chars at " + str(src)
|
||||
)
|
||||
raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src))
|
||||
return (value, pos)
|
||||
|
||||
|
||||
|
@ -707,9 +688,7 @@ def parse_name(src: const_char_p) -> const_char_p:
|
|||
# }
|
||||
def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p:
|
||||
pos = const_char_p(src) # type: const_char_p
|
||||
while pos[0] in (" ", "\t", "#") or (
|
||||
newline_ok and pos[0] in ("\r", "\n")
|
||||
):
|
||||
while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")):
|
||||
if pos[0] == "#":
|
||||
while pos[0] is not None and pos[0] not in ("\r", "\n"):
|
||||
pos += 1
|
||||
|
@ -728,7 +707,7 @@ def parse_sequence(
|
|||
state: parse_state,
|
||||
src: const_char_p,
|
||||
rule_name: str,
|
||||
out_elements: std.vector[llama_grammar_element],
|
||||
out_elements: std.vector[LlamaGrammarElement],
|
||||
is_nested: bool,
|
||||
) -> const_char_p:
|
||||
# size_t last_sym_start = out_elements.size();
|
||||
|
@ -753,9 +732,7 @@ def parse_sequence(
|
|||
char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
|
||||
pos = char_pair[1]
|
||||
out_elements.push_back(
|
||||
llama_grammar_element(
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0]
|
||||
)
|
||||
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0])
|
||||
)
|
||||
pos = parse_space(pos + 1, is_nested)
|
||||
# } else if (*pos == '[') { // char range(s)
|
||||
|
@ -763,9 +740,7 @@ def parse_sequence(
|
|||
# enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
|
||||
elif pos[0] == "[": # char range(s)
|
||||
pos += 1
|
||||
start_type = (
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR
|
||||
) # type: llama_gretype
|
||||
start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype
|
||||
# if (*pos == '^') {
|
||||
# pos++;
|
||||
# start_type = LLAMA_GRETYPE_CHAR_NOT;
|
||||
|
@ -790,9 +765,7 @@ def parse_sequence(
|
|||
if last_sym_start < out_elements.size()
|
||||
else start_type
|
||||
) # type: llama_gretype
|
||||
out_elements.push_back(
|
||||
llama_grammar_element(type.value, char_pair[0])
|
||||
)
|
||||
out_elements.push_back(LlamaGrammarElement(type, char_pair[0]))
|
||||
# if (pos[0] == '-' && pos[1] != ']') {
|
||||
# auto endchar_pair = parse_char(pos + 1);
|
||||
# pos = endchar_pair.second;
|
||||
|
@ -800,13 +773,11 @@ def parse_sequence(
|
|||
# }
|
||||
# }
|
||||
if pos[0] == "-" and pos[1] != "]":
|
||||
endchar_pair = parse_char(
|
||||
pos + 1
|
||||
) # type: Tuple[int, const_char_p]
|
||||
endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p]
|
||||
pos = endchar_pair[1]
|
||||
out_elements.push_back(
|
||||
llama_grammar_element(
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
|
||||
LlamaGrammarElement(
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
|
||||
endchar_pair[0],
|
||||
)
|
||||
)
|
||||
|
@ -820,15 +791,11 @@ def parse_sequence(
|
|||
# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
|
||||
elif is_word_char(pos[0]): # rule reference
|
||||
name_end = parse_name(pos) # type: const_char_p
|
||||
ref_rule_id = get_symbol_id(
|
||||
state, pos, name_end - pos
|
||||
) # type: int
|
||||
ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int
|
||||
pos = parse_space(name_end, is_nested)
|
||||
last_sym_start = out_elements.size()
|
||||
out_elements.push_back(
|
||||
llama_grammar_element(
|
||||
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id
|
||||
)
|
||||
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id)
|
||||
)
|
||||
# } else if (*pos == '(') { // grouping
|
||||
# // parse nested alternates into synthesized rule
|
||||
|
@ -850,9 +817,7 @@ def parse_sequence(
|
|||
last_sym_start = out_elements.size()
|
||||
# output reference to synthesized rule
|
||||
out_elements.push_back(
|
||||
llama_grammar_element(
|
||||
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
|
||||
)
|
||||
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
|
||||
)
|
||||
if pos[0] != ")":
|
||||
raise RuntimeError("expecting ')' at " + str(pos))
|
||||
|
@ -863,9 +828,7 @@ def parse_sequence(
|
|||
# }
|
||||
elif pos[0] in ("*", "+", "?"): # repetition operator
|
||||
if last_sym_start == out_elements.size():
|
||||
raise RuntimeError(
|
||||
"expecting preceding item to */+/? at " + str(pos)
|
||||
)
|
||||
raise RuntimeError("expecting preceding item to */+/? at " + str(pos))
|
||||
# // apply transformation to previous symbol (last_sym_start to end) according to
|
||||
# // rewrite rules:
|
||||
# // S* --> S' ::= S S' |
|
||||
|
@ -878,8 +841,8 @@ def parse_sequence(
|
|||
# sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
|
||||
sub_rule_id = generate_symbol_id(state, rule_name) # type: int
|
||||
sub_rule = std.vector[
|
||||
llama_grammar_element
|
||||
]() # type: std.vector[llama_grammar_element]
|
||||
LlamaGrammarElement
|
||||
]() # type: std.vector[LlamaGrammarElement]
|
||||
sub_rule.insert(
|
||||
sub_rule.end(),
|
||||
out_elements.begin() + last_sym_start,
|
||||
|
@ -893,13 +856,11 @@ def parse_sequence(
|
|||
# sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
|
||||
if pos[0] in ("*", "+"):
|
||||
sub_rule.push_back(
|
||||
llama_grammar_element(
|
||||
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
|
||||
LlamaGrammarElement(
|
||||
llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id
|
||||
)
|
||||
)
|
||||
sub_rule.push_back(
|
||||
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0)
|
||||
)
|
||||
sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
|
||||
# if (*pos == '+') {
|
||||
# // add preceding symbol as alternate only for '+' (otherwise empty)
|
||||
# sub_rule.insert(
|
||||
|
@ -918,16 +879,12 @@ def parse_sequence(
|
|||
out_elements.begin() + last_sym_start,
|
||||
out_elements.end(),
|
||||
)
|
||||
sub_rule.push_back(
|
||||
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0)
|
||||
)
|
||||
sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
|
||||
add_rule(state, sub_rule_id, sub_rule)
|
||||
# in original rule, replace previous symbol with reference to generated rule
|
||||
out_elements.resize(last_sym_start)
|
||||
out_elements.push_back(
|
||||
llama_grammar_element(
|
||||
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
|
||||
)
|
||||
LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
|
||||
)
|
||||
pos = parse_space(pos + 1, is_nested)
|
||||
# } else {
|
||||
|
@ -965,19 +922,13 @@ def parse_alternates(
|
|||
rule_id: int,
|
||||
is_nested: bool,
|
||||
) -> const_char_p:
|
||||
rule = std.vector() # type: std.vector[llama_grammar_element]
|
||||
pos = parse_sequence(
|
||||
state, src, rule_name, rule, is_nested
|
||||
) # type: const_char_p
|
||||
rule = std.vector() # type: std.vector[LlamaGrammarElement]
|
||||
pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p
|
||||
while pos[0] == "|":
|
||||
rule.push_back(
|
||||
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0)
|
||||
)
|
||||
rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
|
||||
pos = parse_space(pos + 1, True)
|
||||
pos = parse_sequence(state, pos, rule_name, rule, is_nested)
|
||||
rule.push_back(
|
||||
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0)
|
||||
)
|
||||
rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
|
||||
add_rule(state, rule_id, rule)
|
||||
return pos
|
||||
|
||||
|
@ -1017,9 +968,7 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
|
|||
raise RuntimeError("expecting ::= at " + str(pos))
|
||||
|
||||
pos = parse_space(pos + 3, True) # type: const_char_p
|
||||
pos = parse_alternates(
|
||||
state, pos, name, rule_id, False
|
||||
) # type: const_char_p
|
||||
pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p
|
||||
|
||||
if pos[0] == "\r":
|
||||
pos += 2 if pos[1] == "\n" else 1
|
||||
|
@ -1080,7 +1029,7 @@ def print_grammar_char(file: TextIO, c: int) -> None:
|
|||
# default: return false;
|
||||
# }
|
||||
# }
|
||||
def is_char_element(elem: llama_grammar_element) -> bool:
|
||||
def is_char_element(elem: LlamaGrammarElement) -> bool:
|
||||
return elem.type in (
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR.value,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value,
|
||||
|
@ -1097,7 +1046,7 @@ def is_char_element(elem: llama_grammar_element) -> bool:
|
|||
def print_rule(
|
||||
file: TextIO,
|
||||
rule_id: int,
|
||||
rule: std.vector[llama_grammar_element],
|
||||
rule: std.vector[LlamaGrammarElement],
|
||||
symbol_id_names: std.map[int, str],
|
||||
) -> None:
|
||||
# if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
|
||||
|
@ -1105,13 +1054,9 @@ def print_rule(
|
|||
# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
||||
# }
|
||||
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||
if (
|
||||
rule.empty()
|
||||
or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value
|
||||
):
|
||||
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value:
|
||||
raise RuntimeError(
|
||||
"malformed rule, does not end with LLAMA_GRETYPE_END: "
|
||||
+ str(rule_id)
|
||||
"malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
|
||||
)
|
||||
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
|
||||
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||
|
@ -1154,22 +1099,20 @@ def print_rule(
|
|||
# break;
|
||||
# }
|
||||
for i, elem in enumerate(rule[:-1]):
|
||||
case = elem.type # type: int
|
||||
if case == llama_gretype.LLAMA_GRETYPE_END.value:
|
||||
raise RuntimeError(
|
||||
"unexpected end of rule: " + str(rule_id) + "," + str(i)
|
||||
)
|
||||
elif case == llama_gretype.LLAMA_GRETYPE_ALT.value:
|
||||
case = elem.type # type: llama_gretype
|
||||
if case is llama_gretype.LLAMA_GRETYPE_END.value:
|
||||
raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
|
||||
elif case is llama_gretype.LLAMA_GRETYPE_ALT:
|
||||
print("| ", file=file, end="")
|
||||
elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value:
|
||||
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
|
||||
print(f"{symbol_id_names.at(elem.value)} ", file=file, end="")
|
||||
elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value:
|
||||
elif case is llama_gretype.LLAMA_GRETYPE_CHAR:
|
||||
print("[", file=file, end="")
|
||||
print_grammar_char(file, elem.value)
|
||||
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value:
|
||||
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT:
|
||||
print("[^", file=file, end="")
|
||||
print_grammar_char(file, elem.value)
|
||||
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value:
|
||||
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER:
|
||||
if i == 0 or not is_char_element(rule[i - 1]):
|
||||
raise RuntimeError(
|
||||
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: "
|
||||
|
@ -1179,7 +1122,7 @@ def print_rule(
|
|||
)
|
||||
print("-", file=file, end="")
|
||||
print_grammar_char(file, elem.value)
|
||||
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value:
|
||||
elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT:
|
||||
if i == 0 or not is_char_element(rule[i - 1]):
|
||||
raise RuntimeError(
|
||||
"LLAMA_GRETYPE_CHAR_ALT without preceding char: "
|
||||
|
|
Loading…
Reference in a new issue