Merge branch 'main' of github.com:abetlen/llama_cpp_python into main
This commit is contained in:
commit
e632c59fa0
1 changed files with 13 additions and 10 deletions
|
@ -1031,10 +1031,10 @@ def print_grammar_char(file: TextIO, c: int) -> None:
|
|||
# }
|
||||
def is_char_element(elem: LlamaGrammarElement) -> bool:
|
||||
return elem.type in (
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR.value,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1054,9 +1054,10 @@ def print_rule(
|
|||
# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
|
||||
# }
|
||||
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value:
|
||||
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
|
||||
raise RuntimeError(
|
||||
"malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
|
||||
"malformed rule, does not end with LLAMA_GRETYPE_END: "
|
||||
+ str(rule_id)
|
||||
)
|
||||
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
|
||||
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||
|
@ -1100,8 +1101,10 @@ def print_rule(
|
|||
# }
|
||||
for i, elem in enumerate(rule[:-1]):
|
||||
case = elem.type # type: llama_gretype
|
||||
if case is llama_gretype.LLAMA_GRETYPE_END.value:
|
||||
raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
|
||||
if case is llama_gretype.LLAMA_GRETYPE_END:
|
||||
raise RuntimeError(
|
||||
"unexpected end of rule: " + str(rule_id) + "," + str(i)
|
||||
)
|
||||
elif case is llama_gretype.LLAMA_GRETYPE_ALT:
|
||||
print("| ", file=file, end="")
|
||||
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
|
||||
|
@ -1140,8 +1143,8 @@ def print_rule(
|
|||
# fprintf(file, "] ");
|
||||
if is_char_element(elem):
|
||||
if rule[i + 1].type in (
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
|
||||
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
|
||||
):
|
||||
pass
|
||||
else:
|
||||
|
|
Loading…
Reference in a new issue