Bugfix: missing response_format for functionary and llava chat handlers

This commit is contained in:
Andrei Betlen 2023-11-09 00:55:23 -05:00
parent 80f4162bf4
commit b62c449839

View file

@@ -320,8 +320,9 @@ def register_chat_format(name: str):
stop = stop + rstop stop = stop + rstop
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
print("hello world") grammar = llama_grammar.LlamaGrammar.from_string(
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) llama_grammar.JSON_GBNF
)
completion_or_chunks = llama.create_completion( completion_or_chunks = llama.create_completion(
prompt=prompt, prompt=prompt,
@@ -577,6 +578,7 @@ def functionary_chat_handler(
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
stop: Optional[Union[str, List[str]]] = [], stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
max_tokens: int = 256, max_tokens: int = 256,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
@@ -753,6 +755,10 @@ def functionary_chat_handler(
assert isinstance(function_call, str) assert isinstance(function_call, str)
assert stream is False # TODO: support stream mode assert stream is False # TODO: support stream mode
if response_format is not None and response_format["type"] == "json_object":
with suppress_stdout_stderr(disable=llama.verbose):
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
return llama_types.CreateChatCompletionResponse( return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"], id="chat" + completion["id"],
object="chat.completion", object="chat.completion",
@@ -825,6 +831,9 @@ class Llava15ChatHandler:
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
stop: Optional[Union[str, List[str]]] = [], stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat
] = None,
max_tokens: int = 256, max_tokens: int = 256,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
@@ -851,7 +860,6 @@ class Llava15ChatHandler:
if system_prompt != "" if system_prompt != ""
else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
) )
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
user_role = "\nUSER:" user_role = "\nUSER:"
assistant_role = "\nASSISTANT:" assistant_role = "\nASSISTANT:"
llama.reset() llama.reset()
@@ -890,12 +898,14 @@ class Llava15ChatHandler:
ctypes.c_ubyte * len(data_array) ctypes.c_ubyte * len(data_array)
).from_buffer(data_array) ).from_buffer(data_array)
with suppress_stdout_stderr(disable=self.verbose): with suppress_stdout_stderr(disable=self.verbose):
embed = self._llava_cpp.llava_image_embed_make_with_bytes( embed = (
self._llava_cpp.llava_image_embed_make_with_bytes(
ctx_clip=self.clip_ctx, ctx_clip=self.clip_ctx,
n_threads=llama.context_params.n_threads, n_threads=llama.context_params.n_threads,
image_bytes=c_ubyte_ptr, image_bytes=c_ubyte_ptr,
image_bytes_length=len(image_bytes), image_bytes_length=len(image_bytes),
) )
)
try: try:
n_past = ctypes.c_int(llama.n_tokens) n_past = ctypes.c_int(llama.n_tokens)
n_past_p = ctypes.pointer(n_past) n_past_p = ctypes.pointer(n_past)
@@ -917,9 +927,17 @@ class Llava15ChatHandler:
f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
) )
) )
assert llama.n_ctx() >= llama.n_tokens
llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False)) llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
assert llama.n_ctx() >= llama.n_tokens
prompt = llama.input_ids[:llama.n_tokens].tolist() prompt = llama.input_ids[: llama.n_tokens].tolist()
if response_format is not None and response_format["type"] == "json_object":
with suppress_stdout_stderr(disable=self.verbose):
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF
)
return _convert_completion_to_chat( return _convert_completion_to_chat(
llama.create_completion( llama.create_completion(