prevent memory access error by llama_grammar_free

This commit is contained in:
c0sogi 2023-08-07 17:02:33 +09:00
parent b07713cb9f
commit 0d7d2031a9

View file

@ -52,33 +52,12 @@ class LlamaGrammar:
parsed_grammar: "parse_state", parsed_grammar: "parse_state",
) -> None: ) -> None:
"""Initialize the grammar pointer from the parsed state.""" """Initialize the grammar pointer from the parsed state."""
self.parsed_grammar = parsed_grammar self._grammar_rules = (
grammar_rules = (
parsed_grammar.c_rules() parsed_grammar.c_rules()
) # type: std.vector[std.vector[llama_grammar_element]] ) # type: std.vector[std.vector[LlamaGrammarElement]]
self._n_rules = self._grammar_rules.size() # type: int
# Step 1: Convert each list to llama_grammar_element array and get pointer self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int
self.element_arrays = [ self.grammar = self.init()
(llama_grammar_element * len(sublist))(*sublist)
for sublist in grammar_rules
] # type: List[Array[llama_grammar_element]]
# Step 2: Get pointer of each array
self.element_array_pointers = [
cast(subarray, llama_grammar_element_p)
for subarray in self.element_arrays
] # type: List[llama_grammar_element_p]
# Step 3: Make array of these pointers and get its pointer
self.rules = (
llama_grammar_element_p * len(self.element_array_pointers)
)(*self.element_array_pointers)
self.n_rules = c_size_t(grammar_rules.size())
self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root"))
self._grammar = llama_cpp.llama_grammar_init(
self.rules, self.n_rules, self.start_rule_index
)
@classmethod @classmethod
def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar":
@ -110,23 +89,45 @@ class LlamaGrammar:
f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty"
) )
@property def init(self) -> None:
def grammar(self) -> llama_grammar_p: # Step 1: Convert LlamaGrammarElement to llama_grammar_element
if self._grammar is None: self._element_lists = [
raise ValueError( [
f"{self.__class__.__name__}.grammar: grammar is freed" llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value))
) for elem in subvector
return self._grammar ]
for subvector in self._grammar_rules
] # type: List[List[llama_grammar_element]]
@grammar.setter # Step 2: Convert each list to llama_grammar_element array and get pointer
def grammar(self, value: Optional[llama_grammar_p]) -> None: self._element_arrays = [
self._grammar = value (llama_grammar_element * len(sublist))(*sublist)
for sublist in self._element_lists
] # type: List[Array[llama_grammar_element]]
# Step 3: Get pointer of each array
self._element_array_pointers = [
cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays
] # type: List[llama_grammar_element_p]
# Step 4: Make array of these pointers and get its pointer
self._rules = (llama_grammar_element_p * len(self._element_array_pointers))(
*self._element_array_pointers
)
self.grammar = llama_cpp.llama_grammar_init(
self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index)
)
def reset(self) -> None: def reset(self) -> None:
if self.grammar is not None:
llama_cpp.llama_grammar_free(self.grammar) llama_cpp.llama_grammar_free(self.grammar)
self.grammar = llama_cpp.llama_grammar_init( self.init()
self.rules, self.n_rules, self.start_rule_index
)
class LlamaGrammarElement:
def __init__(self, type: "llama_gretype", value: int):
self.type = type
self.value = value # Unicode code point or rule ID
class const_char_p: class const_char_p:
@ -182,21 +183,15 @@ class const_char_p:
) )
def __eq__(self: Ptr, other: Ptr) -> bool: def __eq__(self: Ptr, other: Ptr) -> bool:
assert ( assert self.value == other.value, "comparing pointers from different strings"
self.value == other.value
), "comparing pointers from different strings"
return self.pos == other.pos return self.pos == other.pos
def __lt__(self: Ptr, other: Ptr) -> bool: def __lt__(self: Ptr, other: Ptr) -> bool:
assert ( assert self.value == other.value, "comparing pointers from different strings"
self.value == other.value
), "comparing pointers from different strings"
return self.pos < other.pos return self.pos < other.pos
def __gt__(self: Ptr, other: Ptr) -> bool: def __gt__(self: Ptr, other: Ptr) -> bool:
assert ( assert self.value == other.value, "comparing pointers from different strings"
self.value == other.value
), "comparing pointers from different strings"
return self.pos > other.pos return self.pos > other.pos
@ -220,9 +215,7 @@ class std:
def _check_version(self): def _check_version(self):
if self._version != self._vector._version: if self._version != self._vector._version:
raise RuntimeError( raise RuntimeError("Iterator used after vector was modified.")
"Iterator used after vector was modified."
)
def __iter__(self): def __iter__(self):
return self return self
@ -280,16 +273,12 @@ class std:
) -> None: ) -> None:
if new_size > self.size(): if new_size > self.size():
if fill_value_factory is None: if fill_value_factory is None:
raise ValueError( raise ValueError("A fill value factory function must be provided.")
"A fill value factory function must be provided."
)
self.reserve(new_size, fill_value_factory) self.reserve(new_size, fill_value_factory)
elif new_size < self.size(): elif new_size < self.size():
self[:] = self[:new_size] self[:] = self[:new_size]
def reserve( def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None:
self, capacity: int, fill_value_factory: Callable[[], T]
) -> None:
if capacity > self.size(): if capacity > self.size():
fill_value = fill_value_factory() fill_value = fill_value_factory()
self.extend([fill_value] * (capacity - self.size())) self.extend([fill_value] * (capacity - self.size()))
@ -401,9 +390,7 @@ class std:
for k in keys: for k in keys:
if k >= key: if k >= key:
return self.iterator(self, k) return self.iterator(self, k)
raise ValueError( raise ValueError("No key found that is not less than the input key")
"No key found that is not less than the input key"
)
except TypeError: except TypeError:
raise TypeError("Keys of type T cannot be sorted.") raise TypeError("Keys of type T cannot be sorted.")
@ -460,9 +447,7 @@ class llama_gretype(Enum):
class parse_state: class parse_state:
def __init__(self): def __init__(self):
self.symbol_ids: std.map[str, int] = std.map() self.symbol_ids: std.map[str, int] = std.map()
self.rules: std.vector[ self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector()
std.vector[llama_grammar_element]
] = std.vector()
# std::vector<const llama_grammar_element *> parse_state::c_rules() { # std::vector<const llama_grammar_element *> parse_state::c_rules() {
# std::vector<const llama_grammar_element *> ret; # std::vector<const llama_grammar_element *> ret;
@ -471,16 +456,16 @@ class parse_state:
# } # }
# return ret; # return ret;
# } # }
def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]: def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]:
ret = ( ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]]
std.vector()
) # type: std.vector[std.vector[llama_grammar_element]]
for rule in self.rules: for rule in self.rules:
ret.push_back(rule.data()) ret.push_back(rule.data())
return ret return ret
def __repr__(self) -> str: def __repr__(self) -> str:
return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" return (
f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})"
)
# struct llama_grammar { # struct llama_grammar {
@ -531,12 +516,12 @@ def generate_symbol_id(state: parse_state, base_name: str) -> int:
def add_rule( def add_rule(
state: parse_state, state: parse_state,
rule_id: int, rule_id: int,
rule: std.vector[llama_grammar_element], rule: std.vector[LlamaGrammarElement],
) -> None: ) -> None:
if state.rules.size() <= rule_id: if state.rules.size() <= rule_id:
state.rules.resize( state.rules.resize(
rule_id + 1, rule_id + 1,
fill_value_factory=std.vector[llama_grammar_element], fill_value_factory=std.vector[LlamaGrammarElement],
) )
state.rules[rule_id] = rule state.rules[rule_id] = rule
@ -575,9 +560,7 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]:
# return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); # return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9');
# } # }
def is_word_char(c: str) -> bool: def is_word_char(c: str) -> bool:
return ( return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9")
)
# std::pair<uint32_t, const char *> parse_hex(const char * src, int size) { # std::pair<uint32_t, const char *> parse_hex(const char * src, int size) {
@ -619,9 +602,7 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]:
break break
pos += 1 pos += 1
if pos != end: if pos != end:
raise RuntimeError( raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src))
"expecting " + str(size) + " hex chars at " + str(src)
)
return (value, pos) return (value, pos)
@ -707,9 +688,7 @@ def parse_name(src: const_char_p) -> const_char_p:
# } # }
def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p:
pos = const_char_p(src) # type: const_char_p pos = const_char_p(src) # type: const_char_p
while pos[0] in (" ", "\t", "#") or ( while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")):
newline_ok and pos[0] in ("\r", "\n")
):
if pos[0] == "#": if pos[0] == "#":
while pos[0] is not None and pos[0] not in ("\r", "\n"): while pos[0] is not None and pos[0] not in ("\r", "\n"):
pos += 1 pos += 1
@ -728,7 +707,7 @@ def parse_sequence(
state: parse_state, state: parse_state,
src: const_char_p, src: const_char_p,
rule_name: str, rule_name: str,
out_elements: std.vector[llama_grammar_element], out_elements: std.vector[LlamaGrammarElement],
is_nested: bool, is_nested: bool,
) -> const_char_p: ) -> const_char_p:
# size_t last_sym_start = out_elements.size(); # size_t last_sym_start = out_elements.size();
@ -753,9 +732,7 @@ def parse_sequence(
char_pair = parse_char(pos) # type: Tuple[int, const_char_p] char_pair = parse_char(pos) # type: Tuple[int, const_char_p]
pos = char_pair[1] pos = char_pair[1]
out_elements.push_back( out_elements.push_back(
llama_grammar_element( LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0])
llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0]
)
) )
pos = parse_space(pos + 1, is_nested) pos = parse_space(pos + 1, is_nested)
# } else if (*pos == '[') { // char range(s) # } else if (*pos == '[') { // char range(s)
@ -763,9 +740,7 @@ def parse_sequence(
# enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR;
elif pos[0] == "[": # char range(s) elif pos[0] == "[": # char range(s)
pos += 1 pos += 1
start_type = ( start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype
llama_gretype.LLAMA_GRETYPE_CHAR
) # type: llama_gretype
# if (*pos == '^') { # if (*pos == '^') {
# pos++; # pos++;
# start_type = LLAMA_GRETYPE_CHAR_NOT; # start_type = LLAMA_GRETYPE_CHAR_NOT;
@ -790,9 +765,7 @@ def parse_sequence(
if last_sym_start < out_elements.size() if last_sym_start < out_elements.size()
else start_type else start_type
) # type: llama_gretype ) # type: llama_gretype
out_elements.push_back( out_elements.push_back(LlamaGrammarElement(type, char_pair[0]))
llama_grammar_element(type.value, char_pair[0])
)
# if (pos[0] == '-' && pos[1] != ']') { # if (pos[0] == '-' && pos[1] != ']') {
# auto endchar_pair = parse_char(pos + 1); # auto endchar_pair = parse_char(pos + 1);
# pos = endchar_pair.second; # pos = endchar_pair.second;
@ -800,13 +773,11 @@ def parse_sequence(
# } # }
# } # }
if pos[0] == "-" and pos[1] != "]": if pos[0] == "-" and pos[1] != "]":
endchar_pair = parse_char( endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p]
pos + 1
) # type: Tuple[int, const_char_p]
pos = endchar_pair[1] pos = endchar_pair[1]
out_elements.push_back( out_elements.push_back(
llama_grammar_element( LlamaGrammarElement(
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
endchar_pair[0], endchar_pair[0],
) )
) )
@ -820,15 +791,11 @@ def parse_sequence(
# out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id});
elif is_word_char(pos[0]): # rule reference elif is_word_char(pos[0]): # rule reference
name_end = parse_name(pos) # type: const_char_p name_end = parse_name(pos) # type: const_char_p
ref_rule_id = get_symbol_id( ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int
state, pos, name_end - pos
) # type: int
pos = parse_space(name_end, is_nested) pos = parse_space(name_end, is_nested)
last_sym_start = out_elements.size() last_sym_start = out_elements.size()
out_elements.push_back( out_elements.push_back(
llama_grammar_element( LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id)
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id
)
) )
# } else if (*pos == '(') { // grouping # } else if (*pos == '(') { // grouping
# // parse nested alternates into synthesized rule # // parse nested alternates into synthesized rule
@ -850,9 +817,7 @@ def parse_sequence(
last_sym_start = out_elements.size() last_sym_start = out_elements.size()
# output reference to synthesized rule # output reference to synthesized rule
out_elements.push_back( out_elements.push_back(
llama_grammar_element( LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
)
) )
if pos[0] != ")": if pos[0] != ")":
raise RuntimeError("expecting ')' at " + str(pos)) raise RuntimeError("expecting ')' at " + str(pos))
@ -863,9 +828,7 @@ def parse_sequence(
# } # }
elif pos[0] in ("*", "+", "?"): # repetition operator elif pos[0] in ("*", "+", "?"): # repetition operator
if last_sym_start == out_elements.size(): if last_sym_start == out_elements.size():
raise RuntimeError( raise RuntimeError("expecting preceding item to */+/? at " + str(pos))
"expecting preceding item to */+/? at " + str(pos)
)
# // apply transformation to previous symbol (last_sym_start to end) according to # // apply transformation to previous symbol (last_sym_start to end) according to
# // rewrite rules: # // rewrite rules:
# // S* --> S' ::= S S' | # // S* --> S' ::= S S' |
@ -878,8 +841,8 @@ def parse_sequence(
# sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end());
sub_rule_id = generate_symbol_id(state, rule_name) # type: int sub_rule_id = generate_symbol_id(state, rule_name) # type: int
sub_rule = std.vector[ sub_rule = std.vector[
llama_grammar_element LlamaGrammarElement
]() # type: std.vector[llama_grammar_element] ]() # type: std.vector[LlamaGrammarElement]
sub_rule.insert( sub_rule.insert(
sub_rule.end(), sub_rule.end(),
out_elements.begin() + last_sym_start, out_elements.begin() + last_sym_start,
@ -893,13 +856,11 @@ def parse_sequence(
# sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0});
if pos[0] in ("*", "+"): if pos[0] in ("*", "+"):
sub_rule.push_back( sub_rule.push_back(
llama_grammar_element( LlamaGrammarElement(
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id
) )
) )
sub_rule.push_back( sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0)
)
# if (*pos == '+') { # if (*pos == '+') {
# // add preceding symbol as alternate only for '+' (otherwise empty) # // add preceding symbol as alternate only for '+' (otherwise empty)
# sub_rule.insert( # sub_rule.insert(
@ -918,16 +879,12 @@ def parse_sequence(
out_elements.begin() + last_sym_start, out_elements.begin() + last_sym_start,
out_elements.end(), out_elements.end(),
) )
sub_rule.push_back( sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0)
)
add_rule(state, sub_rule_id, sub_rule) add_rule(state, sub_rule_id, sub_rule)
# in original rule, replace previous symbol with reference to generated rule # in original rule, replace previous symbol with reference to generated rule
out_elements.resize(last_sym_start) out_elements.resize(last_sym_start)
out_elements.push_back( out_elements.push_back(
llama_grammar_element( LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id)
llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id
)
) )
pos = parse_space(pos + 1, is_nested) pos = parse_space(pos + 1, is_nested)
# } else { # } else {
@ -965,19 +922,13 @@ def parse_alternates(
rule_id: int, rule_id: int,
is_nested: bool, is_nested: bool,
) -> const_char_p: ) -> const_char_p:
rule = std.vector() # type: std.vector[llama_grammar_element] rule = std.vector() # type: std.vector[LlamaGrammarElement]
pos = parse_sequence( pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p
state, src, rule_name, rule, is_nested
) # type: const_char_p
while pos[0] == "|": while pos[0] == "|":
rule.push_back( rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0))
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0)
)
pos = parse_space(pos + 1, True) pos = parse_space(pos + 1, True)
pos = parse_sequence(state, pos, rule_name, rule, is_nested) pos = parse_sequence(state, pos, rule_name, rule, is_nested)
rule.push_back( rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0))
llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0)
)
add_rule(state, rule_id, rule) add_rule(state, rule_id, rule)
return pos return pos
@ -1017,9 +968,7 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p:
raise RuntimeError("expecting ::= at " + str(pos)) raise RuntimeError("expecting ::= at " + str(pos))
pos = parse_space(pos + 3, True) # type: const_char_p pos = parse_space(pos + 3, True) # type: const_char_p
pos = parse_alternates( pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p
state, pos, name, rule_id, False
) # type: const_char_p
if pos[0] == "\r": if pos[0] == "\r":
pos += 2 if pos[1] == "\n" else 1 pos += 2 if pos[1] == "\n" else 1
@ -1080,7 +1029,7 @@ def print_grammar_char(file: TextIO, c: int) -> None:
# default: return false; # default: return false;
# } # }
# } # }
def is_char_element(elem: llama_grammar_element) -> bool: def is_char_element(elem: LlamaGrammarElement) -> bool:
return elem.type in ( return elem.type in (
llama_gretype.LLAMA_GRETYPE_CHAR.value, llama_gretype.LLAMA_GRETYPE_CHAR.value,
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value,
@ -1097,7 +1046,7 @@ def is_char_element(elem: llama_grammar_element) -> bool:
def print_rule( def print_rule(
file: TextIO, file: TextIO,
rule_id: int, rule_id: int,
rule: std.vector[llama_grammar_element], rule: std.vector[LlamaGrammarElement],
symbol_id_names: std.map[int, str], symbol_id_names: std.map[int, str],
) -> None: ) -> None:
# if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) {
@ -1105,13 +1054,9 @@ def print_rule(
# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
# } # }
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
if ( if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value:
rule.empty()
or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value
):
raise RuntimeError( raise RuntimeError(
"malformed rule, does not end with LLAMA_GRETYPE_END: " "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
+ str(rule_id)
) )
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) { # for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
@ -1154,22 +1099,20 @@ def print_rule(
# break; # break;
# } # }
for i, elem in enumerate(rule[:-1]): for i, elem in enumerate(rule[:-1]):
case = elem.type # type: int case = elem.type # type: llama_gretype
if case == llama_gretype.LLAMA_GRETYPE_END.value: if case is llama_gretype.LLAMA_GRETYPE_END.value:
raise RuntimeError( raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
"unexpected end of rule: " + str(rule_id) + "," + str(i) elif case is llama_gretype.LLAMA_GRETYPE_ALT:
)
elif case == llama_gretype.LLAMA_GRETYPE_ALT.value:
print("| ", file=file, end="") print("| ", file=file, end="")
elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value: elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") print(f"{symbol_id_names.at(elem.value)} ", file=file, end="")
elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value: elif case is llama_gretype.LLAMA_GRETYPE_CHAR:
print("[", file=file, end="") print("[", file=file, end="")
print_grammar_char(file, elem.value) print_grammar_char(file, elem.value)
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value: elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT:
print("[^", file=file, end="") print("[^", file=file, end="")
print_grammar_char(file, elem.value) print_grammar_char(file, elem.value)
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value: elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER:
if i == 0 or not is_char_element(rule[i - 1]): if i == 0 or not is_char_element(rule[i - 1]):
raise RuntimeError( raise RuntimeError(
"LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: "
@ -1179,7 +1122,7 @@ def print_rule(
) )
print("-", file=file, end="") print("-", file=file, end="")
print_grammar_char(file, elem.value) print_grammar_char(file, elem.value)
elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value: elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT:
if i == 0 or not is_char_element(rule[i - 1]): if i == 0 or not is_char_element(rule[i - 1]):
raise RuntimeError( raise RuntimeError(
"LLAMA_GRETYPE_CHAR_ALT without preceding char: " "LLAMA_GRETYPE_CHAR_ALT without preceding char: "