diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 4eb2b02..81ca552 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -339,16 +339,7 @@ def chat_formatter_to_chat_completion_handler(
         stop = stop + rstop
 
         if response_format is not None and response_format["type"] == "json_object":
-            try:
-                # create grammar from json schema
-                if "schema" in response_format:
-                    grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                        json.dumps(response_format["schema"]), verbose=llama.verbose
-                    )
-            except Exception as e:
-                grammar = llama_grammar.LlamaGrammar.from_string(
-                    llama_grammar.JSON_GBNF, verbose=llama.verbose
-                )
+            grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
 
         completion_or_chunks = llama.create_completion(
             prompt=prompt,
@@ -606,6 +597,35 @@ def _format_chatglm3(
             ret += role
     return ret
 
+def _grammar_for_json(verbose: bool = False):
+    return llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF, verbose=verbose)
+
+def _grammar_for_json_schema(
+    schema: str,
+    verbose: bool = False,
+    fallback_to_json: bool = True
+):
+    try:
+        return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose)
+    except Exception as e:
+        if fallback_to_json:
+            return _grammar_for_json(verbose=verbose)
+        else:
+            raise e
+
+def _grammar_for_response_format(
+    response_format: llama_types.ChatCompletionRequestResponseFormat,
+    verbose: bool = False
+):
+    if response_format["type"] != "json_object":
+        return None
+
+    if "schema" in response_format:
+        return _grammar_for_json_schema(
+            json.dumps(response_format["schema"]), verbose=verbose
+        )
+    else:
+        return _grammar_for_json(verbose=verbose)
 
 ### Chat Formats ###
 
@@ -1994,16 +2014,7 @@ class Llava15ChatHandler:
         prompt = llama.input_ids[: llama.n_tokens].tolist()
 
         if response_format is not None and response_format["type"] == "json_object":
-            try:
-                # create grammar from json schema
-                if "schema" in response_format:
-                    grammar = llama_grammar.LlamaGrammar.from_json_schema(
-                        json.dumps(response_format["schema"])
-                    )
-            except Exception as e:
-                grammar = llama_grammar.LlamaGrammar.from_string(
-                    llama_grammar.JSON_GBNF
-                )
+            grammar = _grammar_for_response_format(response_format)
 
         return _convert_completion_to_chat(
             llama.create_completion(
@@ -2159,26 +2170,10 @@ def chatml_function_calling(
             tool_calls=None,
             add_generation_prompt=True,
         )
+
         if response_format is not None and response_format["type"] == "json_object":
-            try:
-                grammar = (
-                    llama_grammar.LlamaGrammar.from_json_schema(
-                        json.dumps(response_format["schema"])
-                    )
-                    if "schema" in response_format
-                    else None
-                )
-            except Exception as e:
-                if llama.verbose:
-                    print(
-                        "Failed to parse response format as JSON schema, falling back to default grammar"
-                    )
-                    print(e)
-                grammar = (
-                    llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
-                    if grammar is None
-                    else grammar
-                )
+            grammar = _grammar_for_response_format(response_format)
+
         return _convert_completion_to_chat(
             llama.create_completion(
                 prompt=prompt,
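
For reference, a minimal usage sketch of the consolidated helper, assuming this patch is applied. The `response_format` dicts mirror `llama_types.ChatCompletionRequestResponseFormat`, and the underscore-prefixed helpers are imported directly from the module purely for illustration:

```python
from llama_cpp import llama_chat_format

# Plain JSON mode: no "schema" key, so the generic JSON_GBNF grammar is returned.
json_grammar = llama_chat_format._grammar_for_response_format({"type": "json_object"})

# Schema-constrained mode: the schema is compiled to a grammar. If compilation
# fails, _grammar_for_json_schema falls back to the generic JSON grammar
# (fallback_to_json=True by default) instead of raising, matching the
# try/except behavior of the three call sites replaced above.
schema_grammar = llama_chat_format._grammar_for_response_format(
    {"type": "json_object", "schema": {"type": "object"}}
)

# Any non-"json_object" response format yields no grammar.
assert llama_chat_format._grammar_for_response_format({"type": "text"}) is None
```

Note that a schema that fails to compile still silently degrades to unconstrained JSON output; callers that want a hard failure can call `_grammar_for_json_schema(..., fallback_to_json=False)` directly.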