From b30b9c338bf9af316d497ea501d39f5c246900db Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 8 Nov 2023 00:07:16 -0500 Subject: [PATCH] Add JSON mode support. Closes #881 --- llama_cpp/llama.py | 2 + llama_cpp/llama_chat_format.py | 144 ++++++++++++++++++++++++--------- llama_cpp/llama_types.py | 6 +- llama_cpp/server/app.py | 3 + 4 files changed, 116 insertions(+), 39 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index b1ba2a0..173f132 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1901,6 +1901,7 @@ class Llama: stream: bool = False, stop: Optional[Union[str, List[str]]] = [], seed: Optional[int] = None, + response_format: Optional[ChatCompletionRequestResponseFormat] = None, max_tokens: int = 256, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, @@ -1946,6 +1947,7 @@ class Llama: stream=stream, stop=stop, seed=seed, + response_format=response_format, max_tokens=max_tokens, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index c9a4775..512de6f 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -5,8 +5,9 @@ import ctypes import dataclasses from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol -import llama_cpp.llama_types as llama_types import llama_cpp.llama as llama +import llama_cpp.llama_types as llama_types +import llama_cpp.llama_grammar as llama_grammar class LlamaChatCompletionHandler(Protocol): @@ -25,6 +26,9 @@ class LlamaChatCompletionHandler(Protocol): stream: bool = False, stop: Optional[Union[str, List[str]]] = [], seed: Optional[int] = None, + response_format: Optional[ + llama_types.ChatCompletionRequestResponseFormat + ] = None, max_tokens: int = 256, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, @@ -37,7 +41,10 @@ class LlamaChatCompletionHandler(Protocol): logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, **kwargs, # type: ignore - ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]: + ) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], + ]: ... @@ -169,6 +176,7 @@ class ChatFormatterResponse: class ChatFormatter(Protocol): def __call__( self, + *, messages: List[llama_types.ChatCompletionRequestMessage], **kwargs: Any, ) -> ChatFormatterResponse: @@ -264,17 +272,24 @@ _CHAT_FORMATS: Dict[str, ChatFormatter] = {} def register_chat_format(name: str): def decorator(f: ChatFormatter): def basic_create_chat_completion( + *, llama: llama.Llama, messages: List[llama_types.ChatCompletionRequestMessage], functions: Optional[List[llama_types.ChatCompletionFunction]] = None, function_call: Optional[ - Union[str, llama_types.ChatCompletionFunctionCall] + llama_types.ChatCompletionRequestFunctionCall ] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, stream: bool = False, stop: Optional[Union[str, List[str]]] = [], + seed: Optional[int] = None, + response_format: Optional[ + llama_types.ChatCompletionRequestResponseFormat + ] = None, max_tokens: int = 256, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, @@ -286,8 +301,10 @@ def register_chat_format(name: str): model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, + **kwargs, # type: ignore ) -> Union[ - llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk] + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], ]: result = f( messages=messages, @@ -299,6 +316,10 @@ def register_chat_format(name: str): stop = [] if stop is None else [stop] if isinstance(stop, str) else stop rstop = result.stop if isinstance(result.stop, list) else [result.stop] stop = stop + rstop + + if response_format is not None and response_format["type"] == "json_object": + print("hello world") + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) completion_or_chunks = llama.create_completion( prompt=prompt, @@ -307,6 +328,7 @@ def register_chat_format(name: str): top_k=top_k, stream=stream, stop=stop, + seed=seed, max_tokens=max_tokens, presence_penalty=presence_penalty, frequency_penalty=frequency_penalty, @@ -319,7 +341,7 @@ def register_chat_format(name: str): logits_processor=logits_processor, grammar=grammar, ) - return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore + return _convert_completion_to_chat(completion_or_chunks, stream=stream) register_chat_completion_handler(name)(basic_create_chat_completion) return f @@ -727,7 +749,7 @@ def functionary_chat_handler( assert "usage" in completion assert isinstance(function_call, str) - assert stream is False # TODO: support stream mode + assert stream is False # TODO: support stream mode return llama_types.CreateChatCompletionResponse( id="chat" + completion["id"], @@ -759,7 +781,9 @@ class Llava15ChatHandler: self._llava_cpp = llava_cpp self.clip_model_path = clip_model_path - self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0) + self.clip_ctx = self._llava_cpp.clip_model_load( + self.clip_model_path.encode(), 0 + ) def __del__(self): if self.clip_ctx is not None: @@ -805,12 +829,21 @@ class Llava15ChatHandler: logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, **kwargs, # type: ignore - ) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]: - assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava + ) -> Union[ + llama_types.CreateChatCompletionResponse, + Iterator[llama_types.CreateChatCompletionStreamResponse], + ]: + assert ( + llama.context_params.logits_all is True + ) # BUG: logits_all=True is required for llava assert self.clip_ctx is not None system_prompt = _get_system_message(messages) - system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." - system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." + system_prompt = ( + system_prompt + if system_prompt != "" + else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." + ) + system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." user_role = "\nUSER:" assistant_role = "\nASSISTANT:" llama.reset() @@ -818,51 +851,86 @@ class Llava15ChatHandler: for message in messages: if message["role"] == "user" and message["content"] is not None: if isinstance(message["content"], str): - llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False)) + llama.eval( + llama.tokenize( + f"{user_role} {message['content']}".encode("utf8"), + add_bos=False, + ) + ) else: assert isinstance(message["content"], list) - llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)) + llama.eval( + llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False) + ) for content in message["content"]: if content["type"] == "text": - llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False)) + llama.eval( + llama.tokenize( + f"{content['text']}".encode("utf8"), add_bos=False + ) + ) if content["type"] == "image_url": - image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"]) + image_bytes = ( + self.load_image(content["image_url"]["url"]) + if isinstance(content["image_url"], dict) + else self.load_image(content["image_url"]) + ) import array - data_array = array.array('B', image_bytes) - c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array) - embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes)) + + data_array = array.array("B", image_bytes) + c_ubyte_ptr = ( + ctypes.c_ubyte * len(data_array) + ).from_buffer(data_array) + embed = self._llava_cpp.llava_image_embed_make_with_bytes( + ctx_clip=self.clip_ctx, + n_threads=llama.context_params.n_threads, + image_bytes=c_ubyte_ptr, + image_bytes_length=len(image_bytes), + ) # image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes) # embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes)) try: n_past = ctypes.c_int(llama.n_tokens) n_past_p = ctypes.pointer(n_past) - self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p) + self._llava_cpp.llava_eval_image_embed( + ctx_llama=llama.ctx, + embed=embed, + n_batch=llama.n_batch, + n_past=n_past_p, + ) assert llama.n_ctx() >= n_past.value llama.n_tokens = n_past.value finally: self._llava_cpp.llava_image_embed_free(embed) if message["role"] == "assistant" and message["content"] is not None: - llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False)) + llama.eval( + llama.tokenize( + f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False + ) + ) llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False)) prompt = llama._input_ids.tolist() - return _convert_completion_to_chat(llama.create_completion( - prompt=prompt, - temperature=temperature, - top_p=top_p, - top_k=top_k, + return _convert_completion_to_chat( + llama.create_completion( + prompt=prompt, + temperature=temperature, + top_p=top_p, + top_k=top_k, + stream=stream, + stop=stop, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, + grammar=grammar, + ), stream=stream, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repeat_penalty=repeat_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - model=model, - logits_processor=logits_processor, - grammar=grammar, - ), stream=stream) \ No newline at end of file + ) diff --git a/llama_cpp/llama_types.py b/llama_cpp/llama_types.py index 69d07fc..b1d7d01 100644 --- a/llama_cpp/llama_types.py +++ b/llama_cpp/llama_types.py @@ -152,6 +152,10 @@ class ChatCompletionFunctionCallOption(TypedDict): name: str +class ChatCompletionRequestResponseFormat(TypedDict): + type: Literal["text", "json_object"] + + class ChatCompletionRequestMessageContentPartText(TypedDict): type: Literal["text"] text: str @@ -241,7 +245,7 @@ ChatCompletionRequestFunctionCall = Union[ Literal["none", "auto"], ChatCompletionRequestFunctionCallOption ] -ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific +ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific class ChatCompletionToolFunction(TypedDict): diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 7caee89..e4e1891 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -792,6 +792,9 @@ class CreateChatCompletionRequest(BaseModel): frequency_penalty: Optional[float] = frequency_penalty_field logit_bias: Optional[Dict[str, float]] = Field(None) seed: Optional[int] = Field(None) + response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field( + default=None, + ) # ignored or currently unsupported model: Optional[str] = model_field