Fix typos in llama_grammar

This commit is contained in:
c0sogi 2023-08-17 21:00:44 +09:00
parent 6dfb98117e
commit a240aa6b25

View file

@ -1031,10 +1031,10 @@ def print_grammar_char(file: TextIO, c: int) -> None:
# } # }
def is_char_element(elem: LlamaGrammarElement) -> bool: def is_char_element(elem: LlamaGrammarElement) -> bool:
return elem.type in ( return elem.type in (
llama_gretype.LLAMA_GRETYPE_CHAR.value, llama_gretype.LLAMA_GRETYPE_CHAR,
llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, llama_gretype.LLAMA_GRETYPE_CHAR_NOT,
llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
) )
@ -1054,9 +1054,10 @@ def print_rule(
# "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id));
# } # }
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value: if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
raise RuntimeError( raise RuntimeError(
"malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) "malformed rule, does not end with LLAMA_GRETYPE_END: "
+ str(rule_id)
) )
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) { # for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
@ -1100,8 +1101,10 @@ def print_rule(
# } # }
for i, elem in enumerate(rule[:-1]): for i, elem in enumerate(rule[:-1]):
case = elem.type # type: llama_gretype case = elem.type # type: llama_gretype
if case is llama_gretype.LLAMA_GRETYPE_END.value: if case is llama_gretype.LLAMA_GRETYPE_END:
raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) raise RuntimeError(
"unexpected end of rule: " + str(rule_id) + "," + str(i)
)
elif case is llama_gretype.LLAMA_GRETYPE_ALT: elif case is llama_gretype.LLAMA_GRETYPE_ALT:
print("| ", file=file, end="") print("| ", file=file, end="")
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
@ -1140,8 +1143,8 @@ def print_rule(
# fprintf(file, "] "); # fprintf(file, "] ");
if is_char_element(elem): if is_char_element(elem):
if rule[i + 1].type in ( if rule[i + 1].type in (
llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, llama_gretype.LLAMA_GRETYPE_CHAR_ALT,
llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER,
): ):
pass pass
else: else: