fix: json mode
This commit is contained in:
parent
1a9b8af2dd
commit
20e6815252
1 changed file with 34 additions and 39 deletions
|
@ -339,16 +339,7 @@ def chat_formatter_to_chat_completion_handler(
|
|||
stop = stop + rstop
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
try:
|
||||
# create grammar from json schema
|
||||
if "schema" in response_format:
|
||||
grammar = llama_grammar.LlamaGrammar.from_json_schema(
|
||||
json.dumps(response_format["schema"]), verbose=llama.verbose
|
||||
)
|
||||
except Exception as e:
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
||||
llama_grammar.JSON_GBNF, verbose=llama.verbose
|
||||
)
|
||||
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
|
||||
|
||||
completion_or_chunks = llama.create_completion(
|
||||
prompt=prompt,
|
||||
|
@ -606,6 +597,35 @@ def _format_chatglm3(
|
|||
ret += role
|
||||
return ret
|
||||
|
||||
def _grammar_for_json(verbose: bool = False):
    """Return the grammar built from the bundled ``JSON_GBNF`` definition.

    Args:
        verbose: Forwarded unchanged to ``LlamaGrammar.from_string``.

    Returns:
        Whatever ``llama_grammar.LlamaGrammar.from_string`` returns for
        ``llama_grammar.JSON_GBNF``.
    """
    gbnf_source = llama_grammar.JSON_GBNF
    return llama_grammar.LlamaGrammar.from_string(gbnf_source, verbose=verbose)
|
||||
|
||||
def _grammar_for_json_schema(
    schema: str,
    verbose: bool = False,
    fallback_to_json: bool = True
):
    """Build a grammar from a JSON-schema string, optionally falling back.

    Args:
        schema: JSON schema serialized as a string; passed unchanged to
            ``llama_grammar.LlamaGrammar.from_json_schema``.
        verbose: Forwarded to the grammar constructors. When true, a notice
            is also written to stderr if the schema grammar cannot be built
            and the fallback is used.
        fallback_to_json: If true (the default), any ``Exception`` raised
            while building the schema grammar is swallowed and the result
            of ``_grammar_for_json`` is returned instead.

    Returns:
        The value returned by ``LlamaGrammar.from_json_schema`` or, on
        fallback, by ``_grammar_for_json``.

    Raises:
        Exception: Whatever ``from_json_schema`` raised, propagated
            unchanged when ``fallback_to_json`` is false.
    """
    try:
        return llama_grammar.LlamaGrammar.from_json_schema(schema, verbose=verbose)
    except Exception as e:
        if not fallback_to_json:
            # Bare ``raise`` propagates the active exception with its
            # original traceback instead of re-raising through a name.
            raise
        if verbose:
            # Surface why the schema was rejected instead of silently
            # degrading to the generic JSON grammar.
            import sys

            print(
                "Failed to parse response format as JSON schema, falling back to default grammar",
                file=sys.stderr,
            )
            print(e, file=sys.stderr)
        return _grammar_for_json(verbose=verbose)
|
||||
|
||||
def _grammar_for_response_format(
    response_format: llama_types.ChatCompletionRequestResponseFormat,
    verbose: bool = False
):
    """Pick the grammar that corresponds to a chat ``response_format``.

    Args:
        response_format: Mapping with a ``"type"`` key and, optionally, a
            ``"schema"`` key.
        verbose: Forwarded to the grammar helper that is selected.

    Returns:
        ``None`` when ``response_format["type"]`` is not ``"json_object"``.
        Otherwise the result of ``_grammar_for_json_schema`` on the
        JSON-serialized ``"schema"`` entry when that key is present, or the
        result of ``_grammar_for_json`` when it is absent.
    """
    if response_format["type"] == "json_object":
        if "schema" not in response_format:
            return _grammar_for_json(verbose=verbose)

        # The schema helper takes a string, so serialize the entry first.
        schema_text = json.dumps(response_format["schema"])
        return _grammar_for_json_schema(schema_text, verbose=verbose)

    return None
|
||||
|
||||
### Chat Formats ###
|
||||
|
||||
|
@ -1994,16 +2014,7 @@ class Llava15ChatHandler:
|
|||
prompt = llama.input_ids[: llama.n_tokens].tolist()
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
try:
|
||||
# create grammar from json schema
|
||||
if "schema" in response_format:
|
||||
grammar = llama_grammar.LlamaGrammar.from_json_schema(
|
||||
json.dumps(response_format["schema"])
|
||||
)
|
||||
except Exception as e:
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
||||
llama_grammar.JSON_GBNF
|
||||
)
|
||||
grammar = _grammar_for_response_format(response_format)
|
||||
|
||||
return _convert_completion_to_chat(
|
||||
llama.create_completion(
|
||||
|
@ -2159,26 +2170,10 @@ def chatml_function_calling(
|
|||
tool_calls=None,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
try:
|
||||
grammar = (
|
||||
llama_grammar.LlamaGrammar.from_json_schema(
|
||||
json.dumps(response_format["schema"])
|
||||
)
|
||||
if "schema" in response_format
|
||||
else None
|
||||
)
|
||||
except Exception as e:
|
||||
if llama.verbose:
|
||||
print(
|
||||
"Failed to parse response format as JSON schema, falling back to default grammar"
|
||||
)
|
||||
print(e)
|
||||
grammar = (
|
||||
llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||
if grammar is None
|
||||
else grammar
|
||||
)
|
||||
grammar = _grammar_for_response_format(response_format)
|
||||
|
||||
return _convert_completion_to_chat(
|
||||
llama.create_completion(
|
||||
prompt=prompt,
|
||||
|
|
Loading…
Reference in a new issue