Merge branch 'main' of github.com:abetlen/llama_cpp_python into main

This commit is contained in:
Andrei Betlen 2023-08-17 20:53:04 -04:00
commit e632c59fa0

View file

@ -1031,10 +1031,10 @@ def print_grammar_char(file: TextIO, c: int) -> None:
# }
def is_char_element(elem: LlamaGrammarElement) -> bool:
return elem.type in (
llama_gretype.LLAMA_GRETYPE_CHAR.value,
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value,
llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value,
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
llama_gretype.LLAMA_GRETYPE_CHAR,
llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
)
@ -1054,9 +1054,10 @@ def print_rule(
# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
# }
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value:
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
raise RuntimeError(
"malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
"malformed rule, does not end with LLAMA_GRETYPE_END: "
+ str(rule_id)
)
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
@ -1100,8 +1101,10 @@ def print_rule(
# }
for i, elem in enumerate(rule[:-1]):
case = elem.type # type: llama_gretype
if case is llama_gretype.LLAMA_GRETYPE_END.value:
raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
if case is llama_gretype.LLAMA_GRETYPE_END:
raise RuntimeError(
"unexpected end of rule: " + str(rule_id) + "," + str(i)
)
elif case is llama_gretype.LLAMA_GRETYPE_ALT:
print("| ", file=file, end="")
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
@ -1140,8 +1143,8 @@ def print_rule(
# fprintf(file, "] ");
if is_char_element(elem):
if rule[i + 1].type in (
llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value,
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
):
pass
else: