fix: json mode

This commit is contained in:
Andrei Betlen 2024-03-15 12:58:34 -04:00
parent 1a9b8af2dd
commit 20e6815252

View file

@ -339,16 +339,7 @@ def chat_formatter_to_chat_completion_handler(
stop = stop + rstop
if response_format is not None and response_format["type"] == "json_object":
try:
# create grammar from json schema
if "schema" in response_format:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(response_format["schema"]), verbose=llama.verbose
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF, verbose=llama.verbose
)
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
completion_or_chunks = llama.create_completion(
prompt=prompt,
@ -606,6 +597,35 @@ def _format_chatglm3(
ret += role
return ret
def _grammar_for_json(verbose: bool = False):
    """Build a grammar object from the module's ``JSON_GBNF`` grammar text.

    Args:
        verbose: Forwarded unchanged to ``LlamaGrammar.from_string``.

    Returns:
        Whatever ``llama_grammar.LlamaGrammar.from_string`` returns for
        ``llama_grammar.JSON_GBNF``.
    """
    json_gbnf = llama_grammar.JSON_GBNF
    return llama_grammar.LlamaGrammar.from_string(json_gbnf, verbose=verbose)
def _grammar_for_json_schema(
    schema: str,
    verbose: bool = False,
    fallback_to_json: bool = True
):
    """Build a grammar from a JSON schema, optionally falling back to plain JSON.

    Args:
        schema: The JSON schema as a string; it is handed as-is to
            ``LlamaGrammar.from_json_schema``.
        verbose: Forwarded to the grammar constructors. When true, a failed
            schema conversion is also reported on stderr before falling back,
            so the substitution of the generic grammar is not silent.
        fallback_to_json: When true (the default), any exception raised while
            converting ``schema`` results in the grammar from
            ``_grammar_for_json`` being returned instead.

    Returns:
        The grammar produced by ``LlamaGrammar.from_json_schema``, or the
        result of ``_grammar_for_json`` when conversion failed and
        ``fallback_to_json`` is true.

    Raises:
        Exception: The original conversion error, re-raised unchanged, when
            ``fallback_to_json`` is false.
    """
    try:
        return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose)
    except Exception as err:
        if not fallback_to_json:
            # Bare ``raise`` re-raises the active exception without adding an
            # extra frame for this line to its traceback.
            raise
        if verbose:
            # Local import keeps this block self-contained; the diagnostic goes
            # to stderr so it never mixes with generated output on stdout.
            import sys

            print(
                "Failed to parse response format as JSON schema, "
                f"falling back to default grammar: {err}",
                file=sys.stderr,
            )
        return _grammar_for_json(verbose=verbose)
def _grammar_for_response_format(
    response_format: llama_types.ChatCompletionRequestResponseFormat,
    verbose: bool = False
):
    """Select the grammar that corresponds to a chat ``response_format``.

    Args:
        response_format: Mapping with a ``"type"`` key and, optionally, a
            ``"schema"`` key holding a JSON-serializable schema.
        verbose: Forwarded to the grammar helper that is chosen.

    Returns:
        ``None`` when ``response_format["type"]`` is not ``"json_object"``;
        otherwise the result of ``_grammar_for_json_schema`` (if a
        ``"schema"`` key is present) or ``_grammar_for_json``.
    """
    # Only the "json_object" type constrains the output with a grammar.
    if response_format["type"] != "json_object":
        return None

    # No schema supplied: use the schema-less JSON grammar helper.
    if "schema" not in response_format:
        return _grammar_for_json(verbose=verbose)

    # The schema helper expects the schema serialized as a JSON string.
    serialized_schema = json.dumps(response_format["schema"])
    return _grammar_for_json_schema(serialized_schema, verbose=verbose)
### Chat Formats ###
@ -1994,16 +2014,7 @@ class Llava15ChatHandler:
prompt = llama.input_ids[: llama.n_tokens].tolist()
if response_format is not None and response_format["type"] == "json_object":
try:
# create grammar from json schema
if "schema" in response_format:
grammar = llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(response_format["schema"])
)
except Exception as e:
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF
)
grammar = _grammar_for_response_format(response_format)
return _convert_completion_to_chat(
llama.create_completion(
@ -2159,26 +2170,10 @@ def chatml_function_calling(
tool_calls=None,
add_generation_prompt=True,
)
if response_format is not None and response_format["type"] == "json_object":
try:
grammar = (
llama_grammar.LlamaGrammar.from_json_schema(
json.dumps(response_format["schema"])
)
if "schema" in response_format
else None
)
except Exception as e:
if llama.verbose:
print(
"Failed to parse response format as JSON schema, falling back to default grammar"
)
print(e)
grammar = (
llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
if grammar is None
else grammar
)
grammar = _grammar_for_response_format(response_format)
return _convert_completion_to_chat(
llama.create_completion(
prompt=prompt,