From d8f6914f459c1a8536ace5b379492a482fa0db16 Mon Sep 17 00:00:00 2001 From: Andrei Date: Sat, 27 Jan 2024 16:52:18 -0500 Subject: [PATCH] Add json schema mode (#1122) * Add json schema mode * Add llava chat format support --- llama_cpp/llama_chat_format.py | 21 ++++++++++++++++----- llama_cpp/llama_types.py | 1 + 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 6c274aa..e418d40 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -318,7 +318,14 @@ def chat_formatter_to_chat_completion_handler( stop = stop + rstop if response_format is not None and response_format["type"] == "json_object": - grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) + try: + # create grammar from json schema + if "schema" in response_format: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(response_format["schema"]) + ) + except Exception as e: + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) completion_or_chunks = llama.create_completion( prompt=prompt, @@ -1434,10 +1441,14 @@ class Llava15ChatHandler: prompt = llama.input_ids[: llama.n_tokens].tolist() if response_format is not None and response_format["type"] == "json_object": - with suppress_stdout_stderr(disable=self.verbose): - grammar = llama_grammar.LlamaGrammar.from_string( - llama_grammar.JSON_GBNF - ) + try: + # create grammar from json schema + if "schema" in response_format: + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(response_format["schema"]) + ) + except Exception as e: + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) return _convert_completion_to_chat( llama.create_completion( diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index 5b51e98..c3deba8 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -154,6 +154,7 @@ class ChatCompletionFunctionCallOption(TypedDict): class ChatCompletionRequestResponseFormat(TypedDict): type: Literal["text", "json_object"] + schema: NotRequired[JsonType] # https://docs.endpoints.anyscale.com/guides/json_mode/ class ChatCompletionRequestMessageContentPartText(TypedDict):