Bugfix: missing response_format for functionary and llava chat handlers

Andrei Betlen 2023-11-09 00:55:23 -05:00
parent 80f4162bf4
commit b62c449839
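
With this change, the functionary and LLaVA chat handlers accept the response_format argument that the other chat formats already honor, so OpenAI-style JSON mode can be requested regardless of which handler is active. A minimal usage sketch of what the fix enables; the model path, chat format, and prompt below are placeholders, not part of the commit:

    from llama_cpp import Llama

    # Placeholder model path; any chat model served through llama-cpp-python is used the same way.
    llm = Llama(model_path="./models/model.gguf", chat_format="chatml")

    # response_format={"type": "json_object"} asks the handler to constrain output to valid JSON
    # by attaching the built-in JSON grammar (llama_grammar.JSON_GBNF) to the completion call.
    result = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You extract structured data as JSON."},
            {"role": "user", "content": "Give the name and age of the first US president."},
        ],
        response_format={"type": "json_object"},
    )
    print(result["choices"][0]["message"]["content"])
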


@@ -318,10 +318,11 @@ def register_chat_format(name: str):
                 stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
                 rstop = result.stop if isinstance(result.stop, list) else [result.stop]
                 stop = stop + rstop
 
             if response_format is not None and response_format["type"] == "json_object":
-                print("hello world")
-                grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF
+                )
 
             completion_or_chunks = llama.create_completion(
                 prompt=prompt,
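
The grammar built in the hunk above is what actually constrains decoding: the handler hands it to llama.create_completion along with the prompt. A minimal sketch of the same mechanism used directly against the public API, assuming the grammar keyword of create_completion; the model path and prompt are illustrative:

    from llama_cpp import Llama
    from llama_cpp.llama_grammar import LlamaGrammar, JSON_GBNF

    llm = Llama(model_path="./models/model.gguf")  # placeholder path

    # Same grammar the chat handler builds when response_format is {"type": "json_object"}.
    grammar = LlamaGrammar.from_string(JSON_GBNF)

    # Sampling is restricted to token sequences the GBNF grammar accepts,
    # so the completion text comes back as parseable JSON.
    out = llm.create_completion(
        prompt="Return a JSON object describing a cat:\n",
        grammar=grammar,
        max_tokens=128,
    )
    print(out["choices"][0]["text"])
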
@@ -577,6 +578,7 @@ def functionary_chat_handler(
     top_k: int = 40,
     stream: bool = False,
     stop: Optional[Union[str, List[str]]] = [],
+    response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
     max_tokens: int = 256,
     presence_penalty: float = 0.0,
     frequency_penalty: float = 0.0,
@@ -753,6 +755,10 @@ def functionary_chat_handler(
     assert isinstance(function_call, str)
     assert stream is False  # TODO: support stream mode
 
+    if response_format is not None and response_format["type"] == "json_object":
+        with suppress_stdout_stderr(disable=llama.verbose):
+            grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
+
     return llama_types.CreateChatCompletionResponse(
         id="chat" + completion["id"],
         object="chat.completion",
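
Before this hunk the functionary handler simply had no response_format to check, so JSON mode could never reach it. A hedged sketch of the call path this enables, assuming a functionary model registered under chat_format="functionary"; the model path and prompt are illustrative:

    from llama_cpp import Llama

    # Placeholder path to a functionary-style model.
    llm = Llama(model_path="./models/functionary-7b.gguf", chat_format="functionary")

    # With this fix the functionary handler builds the JSON grammar instead of ignoring
    # (or rejecting) the response_format argument.
    result = llm.create_chat_completion(
        messages=[
            {"role": "user", "content": "List three fruits as a JSON object under the key 'fruits'."},
        ],
        response_format={"type": "json_object"},
    )
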
@@ -785,11 +791,11 @@ class Llava15ChatHandler:
         self._llava_cpp = llava_cpp
         self.clip_model_path = clip_model_path
         self.verbose = verbose
-        self._clip_free = self._llava_cpp._libllava.clip_free # type: ignore
+        self._clip_free = self._llava_cpp._libllava.clip_free  # type: ignore
 
         with suppress_stdout_stderr(disable=self.verbose):
             self.clip_ctx = self._llava_cpp.clip_model_load(
-                self.clip_model_path.encode(), 0
+                self.clip_model_path.encode(), 0
             )
 
     def __del__(self):
@@ -825,6 +831,9 @@ class Llava15ChatHandler:
         top_k: int = 40,
         stream: bool = False,
         stop: Optional[Union[str, List[str]]] = [],
+        response_format: Optional[
+            llama_types.ChatCompletionRequestResponseFormat
+        ] = None,
         max_tokens: int = 256,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
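
The same applies to the multimodal handler: an image message and JSON mode can now be combined in one call. A sketch of the intended usage following the project's LLaVA setup; the file paths, image URL, and the n_ctx/logits_all settings are illustrative assumptions, not part of this commit:

    from llama_cpp import Llama
    from llama_cpp.llama_chat_format import Llava15ChatHandler

    # The CLIP projector file (mmproj) drives the image encoder used by Llava15ChatHandler.
    chat_handler = Llava15ChatHandler(clip_model_path="./models/mmproj.bin")
    llm = Llama(
        model_path="./models/llava-v1.5-7b.gguf",  # placeholder path
        chat_handler=chat_handler,
        n_ctx=2048,       # enough room for the image embedding plus the prompt
        logits_all=True,  # commonly set for the LLaVA handler in the project's examples
    )

    result = llm.create_chat_completion(
        messages=[
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                    {"type": "text", "text": "Describe this image as a JSON object."},
                ],
            }
        ],
        response_format={"type": "json_object"},
    )
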
@@ -851,7 +860,6 @@ class Llava15ChatHandler:
             if system_prompt != ""
             else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
         )
-        system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
         user_role = "\nUSER:"
         assistant_role = "\nASSISTANT:"
         llama.reset()
@@ -890,11 +898,13 @@ class Llava15ChatHandler:
                     ctypes.c_ubyte * len(data_array)
                 ).from_buffer(data_array)
                 with suppress_stdout_stderr(disable=self.verbose):
-                    embed = self._llava_cpp.llava_image_embed_make_with_bytes(
-                        ctx_clip=self.clip_ctx,
-                        n_threads=llama.context_params.n_threads,
-                        image_bytes=c_ubyte_ptr,
-                        image_bytes_length=len(image_bytes),
+                    embed = (
+                        self._llava_cpp.llava_image_embed_make_with_bytes(
+                            ctx_clip=self.clip_ctx,
+                            n_threads=llama.context_params.n_threads,
+                            image_bytes=c_ubyte_ptr,
+                            image_bytes_length=len(image_bytes),
+                        )
                     )
                 try:
                     n_past = ctypes.c_int(llama.n_tokens)
@@ -917,9 +927,17 @@ class Llava15ChatHandler:
                         f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
                     )
                 )
         assert llama.n_ctx() >= llama.n_tokens
         llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
         assert llama.n_ctx() >= llama.n_tokens
-        prompt = llama.input_ids[:llama.n_tokens].tolist()
+        prompt = llama.input_ids[: llama.n_tokens].tolist()
+
+        if response_format is not None and response_format["type"] == "json_object":
+            with suppress_stdout_stderr(disable=self.verbose):
+                grammar = llama_grammar.LlamaGrammar.from_string(
+                    llama_grammar.JSON_GBNF
+                )
+
         return _convert_completion_to_chat(
             llama.create_completion(