diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index 5388676..f35f9fa 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -40,7 +40,7 @@ class Sentinel: class LlamaGrammar: """Keeps reference counts of all the arguments, so that they are not garbage collected by Python.""" - + def __del__(self) -> None: """Free the grammar pointer when the object is deleted.""" if self.grammar is not None: @@ -52,33 +52,12 @@ class LlamaGrammar: parsed_grammar: "parse_state", ) -> None: """Initialize the grammar pointer from the parsed state.""" - self.parsed_grammar = parsed_grammar - grammar_rules = ( + self._grammar_rules = ( parsed_grammar.c_rules() - ) # type: std.vector[std.vector[llama_grammar_element]] - - # Step 1: Convert each list to llama_grammar_element array and get pointer - self.element_arrays = [ - (llama_grammar_element * len(sublist))(*sublist) - for sublist in grammar_rules - ] # type: List[Array[llama_grammar_element]] - - # Step 2: Get pointer of each array - self.element_array_pointers = [ - cast(subarray, llama_grammar_element_p) - for subarray in self.element_arrays - ] # type: List[llama_grammar_element_p] - - # Step 3: Make array of these pointers and get its pointer - self.rules = ( - llama_grammar_element_p * len(self.element_array_pointers) - )(*self.element_array_pointers) - - self.n_rules = c_size_t(grammar_rules.size()) - self.start_rule_index = c_size_t(parsed_grammar.symbol_ids.at("root")) - self._grammar = llama_cpp.llama_grammar_init( - self.rules, self.n_rules, self.start_rule_index - ) + ) # type: std.vector[std.vector[LlamaGrammarElement]] + self._n_rules = self._grammar_rules.size() # type: int + self._start_rule_index = parsed_grammar.symbol_ids.at("root") # type: int + self.grammar = self.init() @classmethod def from_string(cls, grammar: str, verbose: bool = True) -> "LlamaGrammar": @@ -110,23 +89,45 @@ class LlamaGrammar: f"{cls.from_file.__name__}: error parsing grammar file: params_grammer is empty" ) - @property - def grammar(self) -> llama_grammar_p: - if self._grammar is None: - raise ValueError( - f"{self.__class__.__name__}.grammar: grammar is freed" - ) - return self._grammar - - @grammar.setter - def grammar(self, value: Optional[llama_grammar_p]) -> None: - self._grammar = value + def init(self) -> None: + # Step 1: Convert LlamaGrammarElement to llama_grammar_element + self._element_lists = [ + [ + llama_grammar_element(c_int(elem.type.value), c_uint32(elem.value)) + for elem in subvector + ] + for subvector in self._grammar_rules + ] # type: List[List[llama_grammar_element]] + + # Step 2: Convert each list to llama_grammar_element array and get pointer + self._element_arrays = [ + (llama_grammar_element * len(sublist))(*sublist) + for sublist in self._element_lists + ] # type: List[Array[llama_grammar_element]] + + # Step 3: Get pointer of each array + self._element_array_pointers = [ + cast(subarray, llama_grammar_element_p) for subarray in self._element_arrays + ] # type: List[llama_grammar_element_p] + + # Step 4: Make array of these pointers and get its pointer + self._rules = (llama_grammar_element_p * len(self._element_array_pointers))( + *self._element_array_pointers + ) + self.grammar = llama_cpp.llama_grammar_init( + self._rules, c_size_t(self._n_rules), c_size_t(self._start_rule_index) + ) def reset(self) -> None: - llama_cpp.llama_grammar_free(self.grammar) - self.grammar = llama_cpp.llama_grammar_init( - self.rules, self.n_rules, self.start_rule_index - ) + if self.grammar is not None: + llama_cpp.llama_grammar_free(self.grammar) + self.init() + + +class LlamaGrammarElement: + def __init__(self, type: "llama_gretype", value: int): + self.type = type + self.value = value # Unicode code point or rule ID class const_char_p: @@ -182,21 +183,15 @@ class const_char_p: ) def __eq__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos == other.pos def __lt__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos < other.pos def __gt__(self: Ptr, other: Ptr) -> bool: - assert ( - self.value == other.value - ), "comparing pointers from different strings" + assert self.value == other.value, "comparing pointers from different strings" return self.pos > other.pos @@ -220,9 +215,7 @@ class std: def _check_version(self): if self._version != self._vector._version: - raise RuntimeError( - "Iterator used after vector was modified." - ) + raise RuntimeError("Iterator used after vector was modified.") def __iter__(self): return self @@ -280,16 +273,12 @@ class std: ) -> None: if new_size > self.size(): if fill_value_factory is None: - raise ValueError( - "A fill value factory function must be provided." - ) + raise ValueError("A fill value factory function must be provided.") self.reserve(new_size, fill_value_factory) elif new_size < self.size(): self[:] = self[:new_size] - def reserve( - self, capacity: int, fill_value_factory: Callable[[], T] - ) -> None: + def reserve(self, capacity: int, fill_value_factory: Callable[[], T]) -> None: if capacity > self.size(): fill_value = fill_value_factory() self.extend([fill_value] * (capacity - self.size())) @@ -401,9 +390,7 @@ class std: for k in keys: if k >= key: return self.iterator(self, k) - raise ValueError( - "No key found that is not less than the input key" - ) + raise ValueError("No key found that is not less than the input key") except TypeError: raise TypeError("Keys of type T cannot be sorted.") @@ -460,9 +447,7 @@ class llama_gretype(Enum): class parse_state: def __init__(self): self.symbol_ids: std.map[str, int] = std.map() - self.rules: std.vector[ - std.vector[llama_grammar_element] - ] = std.vector() + self.rules: std.vector[std.vector[LlamaGrammarElement]] = std.vector() # std::vector parse_state::c_rules() { # std::vector ret; @@ -471,16 +456,16 @@ class parse_state: # } # return ret; # } - def c_rules(self) -> std.vector[std.vector[llama_grammar_element]]: - ret = ( - std.vector() - ) # type: std.vector[std.vector[llama_grammar_element]] + def c_rules(self) -> std.vector[std.vector[LlamaGrammarElement]]: + ret = std.vector() # type: std.vector[std.vector[LlamaGrammarElement]] for rule in self.rules: ret.push_back(rule.data()) return ret def __repr__(self) -> str: - return f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + return ( + f"parse_state(symbol_ids={len(self.symbol_ids)}, rules={len(self.rules)})" + ) # struct llama_grammar { @@ -531,12 +516,12 @@ def generate_symbol_id(state: parse_state, base_name: str) -> int: def add_rule( state: parse_state, rule_id: int, - rule: std.vector[llama_grammar_element], + rule: std.vector[LlamaGrammarElement], ) -> None: if state.rules.size() <= rule_id: state.rules.resize( rule_id + 1, - fill_value_factory=std.vector[llama_grammar_element], + fill_value_factory=std.vector[LlamaGrammarElement], ) state.rules[rule_id] = rule @@ -575,9 +560,7 @@ def decode_utf8(src: const_char_p) -> Tuple[int, const_char_p]: # return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '-' || ('0' <= c && c <= '9'); # } def is_word_char(c: str) -> bool: - return ( - ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") - ) + return ("a" <= c <= "z") or ("A" <= c <= "Z") or c == "-" or ("0" <= c <= "9") # std::pair parse_hex(const char * src, int size) { @@ -619,9 +602,7 @@ def parse_hex(src: const_char_p, size: int) -> Tuple[int, const_char_p]: break pos += 1 if pos != end: - raise RuntimeError( - "expecting " + str(size) + " hex chars at " + str(src) - ) + raise RuntimeError("expecting " + str(size) + " hex chars at " + str(src)) return (value, pos) @@ -707,9 +688,7 @@ def parse_name(src: const_char_p) -> const_char_p: # } def parse_space(src: const_char_p, newline_ok: bool) -> const_char_p: pos = const_char_p(src) # type: const_char_p - while pos[0] in (" ", "\t", "#") or ( - newline_ok and pos[0] in ("\r", "\n") - ): + while pos[0] in (" ", "\t", "#") or (newline_ok and pos[0] in ("\r", "\n")): if pos[0] == "#": while pos[0] is not None and pos[0] not in ("\r", "\n"): pos += 1 @@ -728,7 +707,7 @@ def parse_sequence( state: parse_state, src: const_char_p, rule_name: str, - out_elements: std.vector[llama_grammar_element], + out_elements: std.vector[LlamaGrammarElement], is_nested: bool, ) -> const_char_p: # size_t last_sym_start = out_elements.size(); @@ -753,9 +732,7 @@ def parse_sequence( char_pair = parse_char(pos) # type: Tuple[int, const_char_p] pos = char_pair[1] out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR.value, char_pair[0] - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_CHAR, char_pair[0]) ) pos = parse_space(pos + 1, is_nested) # } else if (*pos == '[') { // char range(s) @@ -763,9 +740,7 @@ def parse_sequence( # enum llama_gretype start_type = LLAMA_GRETYPE_CHAR; elif pos[0] == "[": # char range(s) pos += 1 - start_type = ( - llama_gretype.LLAMA_GRETYPE_CHAR - ) # type: llama_gretype + start_type = llama_gretype.LLAMA_GRETYPE_CHAR # type: llama_gretype # if (*pos == '^') { # pos++; # start_type = LLAMA_GRETYPE_CHAR_NOT; @@ -790,9 +765,7 @@ def parse_sequence( if last_sym_start < out_elements.size() else start_type ) # type: llama_gretype - out_elements.push_back( - llama_grammar_element(type.value, char_pair[0]) - ) + out_elements.push_back(LlamaGrammarElement(type, char_pair[0])) # if (pos[0] == '-' && pos[1] != ']') { # auto endchar_pair = parse_char(pos + 1); # pos = endchar_pair.second; @@ -800,13 +773,11 @@ def parse_sequence( # } # } if pos[0] == "-" and pos[1] != "]": - endchar_pair = parse_char( - pos + 1 - ) # type: Tuple[int, const_char_p] + endchar_pair = parse_char(pos + 1) # type: Tuple[int, const_char_p] pos = endchar_pair[1] out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, + LlamaGrammarElement( + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, endchar_pair[0], ) ) @@ -820,15 +791,11 @@ def parse_sequence( # out_elements.push_back({LLAMA_GRETYPE_RULE_REF, ref_rule_id}); elif is_word_char(pos[0]): # rule reference name_end = parse_name(pos) # type: const_char_p - ref_rule_id = get_symbol_id( - state, pos, name_end - pos - ) # type: int + ref_rule_id = get_symbol_id(state, pos, name_end - pos) # type: int pos = parse_space(name_end, is_nested) last_sym_start = out_elements.size() out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, ref_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, ref_rule_id) ) # } else if (*pos == '(') { // grouping # // parse nested alternates into synthesized rule @@ -850,9 +817,7 @@ def parse_sequence( last_sym_start = out_elements.size() # output reference to synthesized rule out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) ) if pos[0] != ")": raise RuntimeError("expecting ')' at " + str(pos)) @@ -863,9 +828,7 @@ def parse_sequence( # } elif pos[0] in ("*", "+", "?"): # repetition operator if last_sym_start == out_elements.size(): - raise RuntimeError( - "expecting preceding item to */+/? at " + str(pos) - ) + raise RuntimeError("expecting preceding item to */+/? at " + str(pos)) # // apply transformation to previous symbol (last_sym_start to end) according to # // rewrite rules: # // S* --> S' ::= S S' | @@ -878,8 +841,8 @@ def parse_sequence( # sub_rule.end(), out_elements.begin() + last_sym_start, out_elements.end()); sub_rule_id = generate_symbol_id(state, rule_name) # type: int sub_rule = std.vector[ - llama_grammar_element - ]() # type: std.vector[llama_grammar_element] + LlamaGrammarElement + ]() # type: std.vector[LlamaGrammarElement] sub_rule.insert( sub_rule.end(), out_elements.begin() + last_sym_start, @@ -893,13 +856,11 @@ def parse_sequence( # sub_rule.push_back({LLAMA_GRETYPE_ALT, 0}); if pos[0] in ("*", "+"): sub_rule.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id + LlamaGrammarElement( + llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id ) ) - sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) - ) + sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) # if (*pos == '+') { # // add preceding symbol as alternate only for '+' (otherwise empty) # sub_rule.insert( @@ -918,16 +879,12 @@ def parse_sequence( out_elements.begin() + last_sym_start, out_elements.end(), ) - sub_rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) - ) + sub_rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, sub_rule_id, sub_rule) # in original rule, replace previous symbol with reference to generated rule out_elements.resize(last_sym_start) out_elements.push_back( - llama_grammar_element( - llama_gretype.LLAMA_GRETYPE_RULE_REF.value, sub_rule_id - ) + LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_RULE_REF, sub_rule_id) ) pos = parse_space(pos + 1, is_nested) # } else { @@ -965,19 +922,13 @@ def parse_alternates( rule_id: int, is_nested: bool, ) -> const_char_p: - rule = std.vector() # type: std.vector[llama_grammar_element] - pos = parse_sequence( - state, src, rule_name, rule, is_nested - ) # type: const_char_p + rule = std.vector() # type: std.vector[LlamaGrammarElement] + pos = parse_sequence(state, src, rule_name, rule, is_nested) # type: const_char_p while pos[0] == "|": - rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_ALT.value, 0) - ) + rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_ALT, 0)) pos = parse_space(pos + 1, True) pos = parse_sequence(state, pos, rule_name, rule, is_nested) - rule.push_back( - llama_grammar_element(llama_gretype.LLAMA_GRETYPE_END.value, 0) - ) + rule.push_back(LlamaGrammarElement(llama_gretype.LLAMA_GRETYPE_END, 0)) add_rule(state, rule_id, rule) return pos @@ -1017,9 +968,7 @@ def parse_rule(state: parse_state, src: const_char_p) -> const_char_p: raise RuntimeError("expecting ::= at " + str(pos)) pos = parse_space(pos + 3, True) # type: const_char_p - pos = parse_alternates( - state, pos, name, rule_id, False - ) # type: const_char_p + pos = parse_alternates(state, pos, name, rule_id, False) # type: const_char_p if pos[0] == "\r": pos += 2 if pos[1] == "\n" else 1 @@ -1080,7 +1029,7 @@ def print_grammar_char(file: TextIO, c: int) -> None: # default: return false; # } # } -def is_char_element(elem: llama_grammar_element) -> bool: +def is_char_element(elem: LlamaGrammarElement) -> bool: return elem.type in ( llama_gretype.LLAMA_GRETYPE_CHAR.value, llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, @@ -1097,7 +1046,7 @@ def is_char_element(elem: llama_grammar_element) -> bool: def print_rule( file: TextIO, rule_id: int, - rule: std.vector[llama_grammar_element], + rule: std.vector[LlamaGrammarElement], symbol_id_names: std.map[int, str], ) -> None: # if (rule.empty() || rule.back().type != LLAMA_GRETYPE_END) { @@ -1105,13 +1054,9 @@ def print_rule( # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # } # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - if ( - rule.empty() - or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value - ): + if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " - + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1154,22 +1099,20 @@ def print_rule( # break; # } for i, elem in enumerate(rule[:-1]): - case = elem.type # type: int - if case == llama_gretype.LLAMA_GRETYPE_END.value: - raise RuntimeError( - "unexpected end of rule: " + str(rule_id) + "," + str(i) - ) - elif case == llama_gretype.LLAMA_GRETYPE_ALT.value: + case = elem.type # type: llama_gretype + if case is llama_gretype.LLAMA_GRETYPE_END.value: + raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) + elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") - elif case == llama_gretype.LLAMA_GRETYPE_RULE_REF.value: + elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: print(f"{symbol_id_names.at(elem.value)} ", file=file, end="") - elif case == llama_gretype.LLAMA_GRETYPE_CHAR.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR: print("[", file=file, end="") print_grammar_char(file, elem.value) - elif case == llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_NOT: print("[^", file=file, end="") print_grammar_char(file, elem.value) - elif case == llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_RNG_UPPER without preceding char: " @@ -1179,7 +1122,7 @@ def print_rule( ) print("-", file=file, end="") print_grammar_char(file, elem.value) - elif case == llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value: + elif case is llama_gretype.LLAMA_GRETYPE_CHAR_ALT: if i == 0 or not is_char_element(rule[i - 1]): raise RuntimeError( "LLAMA_GRETYPE_CHAR_ALT without preceding char: " @@ -1239,4 +1182,4 @@ def print_grammar(file: TextIO, state: parse_state) -> None: print( f"{print_grammar.__name__}: error printing grammar: {err}", file=sys.stderr, - ) \ No newline at end of file + )