Add json schema mode (#1122)

* Add json schema mode

* Add JSON schema mode support to the llava chat handler (Llava15ChatHandler)
This commit is contained in:
Andrei 2024-01-27 16:52:18 -05:00 committed by GitHub
parent c6d3bd62e8
commit d8f6914f45
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 17 additions and 5 deletions

View file

@ -318,7 +318,14 @@ def chat_formatter_to_chat_completion_handler(
stop = stop + rstop stop = stop + rstop
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) try:
# create grammar from json schema
if "schema" in response_format:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(response_format["schema"])
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
completion_or_chunks = llama.create_completion( completion_or_chunks = llama.create_completion(
prompt=prompt, prompt=prompt,
@ -1434,10 +1441,14 @@ class Llava15ChatHandler:
prompt = llama.input_ids[: llama.n_tokens].tolist() prompt = llama.input_ids[: llama.n_tokens].tolist()
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
with suppress_stdout_stderr(disable=self.verbose): try:
grammar = llama_grammar.LlamaGrammar.from_string( # create grammar from json schema
llama_grammar.JSON_GBNF if "schema" in response_format:
) grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(response_format["schema"])
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
return _convert_completion_to_chat( return _convert_completion_to_chat(
llama.create_completion( llama.create_completion(

View file

@ -154,6 +154,7 @@ class ChatCompletionFunctionCallOption(TypedDict):
class ChatCompletionRequestResponseFormat(TypedDict): class ChatCompletionRequestResponseFormat(TypedDict):
type: Literal["text", "json_object"] type: Literal["text", "json_object"]
schema: NotRequired[JsonType] # https://docs.endpoints.anyscale.com/guides/json_mode/
class ChatCompletionRequestMessageContentPartText(TypedDict): class ChatCompletionRequestMessageContentPartText(TypedDict):