From a240aa6b25be46aca06a9f3d0fc5b28528385638 Mon Sep 17 00:00:00 2001 From: c0sogi Date: Thu, 17 Aug 2023 21:00:44 +0900 Subject: [PATCH] Fix typos in llama_grammar --- llama_cpp/llama_grammar.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama_grammar.py b/llama_cpp/llama_grammar.py index f35f9fa..8ff1565 100644 --- a/llama_cpp/llama_grammar.py +++ b/llama_cpp/llama_grammar.py @@ -1031,10 +1031,10 @@ def print_grammar_char(file: TextIO, c: int) -> None: # } def is_char_element(elem: LlamaGrammarElement) -> bool: return elem.type in ( - llama_gretype.LLAMA_GRETYPE_CHAR.value, - llama_gretype.LLAMA_GRETYPE_CHAR_NOT.value, - llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, + llama_gretype.LLAMA_GRETYPE_CHAR, + llama_gretype.LLAMA_GRETYPE_CHAR_NOT, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, ) @@ -1054,9 +1054,10 @@ def print_rule( # "malformed rule, does not end with LLAMA_GRETYPE_END: " + std::to_string(rule_id)); # } # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); - if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END.value: + if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: raise RuntimeError( - "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id) + "malformed rule, does not end with LLAMA_GRETYPE_END: " + + str(rule_id) ) print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") # for (size_t i = 0, end = rule.size() - 1; i < end; i++) { @@ -1100,8 +1101,10 @@ def print_rule( # } for i, elem in enumerate(rule[:-1]): case = elem.type # type: llama_gretype - if case is llama_gretype.LLAMA_GRETYPE_END.value: - raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i)) + if case is llama_gretype.LLAMA_GRETYPE_END: + raise RuntimeError( + "unexpected end of rule: " + str(rule_id) + "," + str(i) + ) elif case is llama_gretype.LLAMA_GRETYPE_ALT: print("| ", file=file, end="") elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: @@ -1140,8 +1143,8 @@ def print_rule( # fprintf(file, "] "); if is_char_element(elem): if rule[i + 1].type in ( - llama_gretype.LLAMA_GRETYPE_CHAR_ALT.value, - llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER.value, + llama_gretype.LLAMA_GRETYPE_CHAR_ALT, + llama_gretype.LLAMA_GRETYPE_CHAR_RNG_UPPER, ): pass else: