diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 6c274aa..e418d40 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -318,7 +318,14 @@ def chat_formatter_to_chat_completion_handler(
         stop = stop + rstop
 
         if response_format is not None and response_format["type"] == "json_object":
-            grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+            try:
+                # create grammar from json schema
+                if "schema" in response_format:
+                    grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                        json.dumps(response_format["schema"])
+                    )
+            except Exception as e:
+                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
 
         completion_or_chunks = llama.create_completion(
             prompt=prompt,
@@ -1434,10 +1441,14 @@ class Llava15ChatHandler:
             prompt = llama.input_ids[: llama.n_tokens].tolist()
 
         if response_format is not None and response_format["type"] == "json_object":
-            with suppress_stdout_stderr(disable=self.verbose):
-                grammar = llama_grammar.LlamaGrammar.from_string(
-                    llama_grammar.JSON_GBNF
-                )
+            try:
+                # create grammar from json schema
+                if "schema" in response_format:
+                    grammar = llama_grammar.LlamaGrammar.from_json_schema(
+                        json.dumps(response_format["schema"])
+                    )
+            except Exception as e:
+                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
 
         return _convert_completion_to_chat(
             llama.create_completion(
diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py
index 5b51e98..c3deba8 100644
--- a/llama_cpp/llama_types.py
+++ b/llama_cpp/llama_types.py
@@ -154,6 +154,7 @@ class ChatCompletionFunctionCallOption(TypedDict):
 
 class ChatCompletionRequestResponseFormat(TypedDict):
     type: Literal["text", "json_object"]
+    schema: NotRequired[JsonType]  # https://docs.endpoints.anyscale.com/guides/json_mode/
 
 
 class ChatCompletionRequestMessageContentPartText(TypedDict):
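For context, a minimal usage sketch of the new optional `schema` key on `response_format` (the model path, messages, and schema below are illustrative assumptions, not part of this diff). With this change, when `schema` is supplied the handlers build the sampling grammar via `LlamaGrammar.from_json_schema`, falling back to the generic `JSON_GBNF` grammar only if schema conversion raises; note that as written, a plain `{"type": "json_object"}` request with no `schema` key leaves `grammar` unset in the patched branches.

```python
# Sketch of calling the patched API; model path, messages, and schema are hypothetical.
from llama_cpp import Llama

llm = Llama(model_path="./models/model.gguf")  # hypothetical model path

result = llm.create_chat_completion(
    messages=[
        {"role": "system", "content": "Extract the person as JSON."},
        {"role": "user", "content": "Alice is 30 years old."},
    ],
    response_format={
        "type": "json_object",
        # New optional key from this diff: output is constrained to this
        # JSON schema via LlamaGrammar.from_json_schema(json.dumps(...)).
        "schema": {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "age": {"type": "integer"},
            },
            "required": ["name", "age"],
        },
    },
)
print(result["choices"][0]["message"]["content"])
```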