Add JSON mode support. Closes #881

This commit is contained in:
Andrei Betlen 2023-11-08 00:07:16 -05:00
parent 4852a6a39c
commit b30b9c338b
4 changed files with 116 additions and 39 deletions

View file

@ -1901,6 +1901,7 @@ class Llama:
stream: bool = False, stream: bool = False,
stop: Optional[Union[str, List[str]]] = [], stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None, seed: Optional[int] = None,
response_format: Optional[ChatCompletionRequestResponseFormat] = None,
max_tokens: int = 256, max_tokens: int = 256,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
@ -1946,6 +1947,7 @@ class Llama:
stream=stream, stream=stream,
stop=stop, stop=stop,
seed=seed, seed=seed,
response_format=response_format,
max_tokens=max_tokens, max_tokens=max_tokens,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,

View file

@ -5,8 +5,9 @@ import ctypes
import dataclasses import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
import llama_cpp.llama_types as llama_types
import llama_cpp.llama as llama import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar
class LlamaChatCompletionHandler(Protocol): class LlamaChatCompletionHandler(Protocol):
@ -25,6 +26,9 @@ class LlamaChatCompletionHandler(Protocol):
stream: bool = False, stream: bool = False,
stop: Optional[Union[str, List[str]]] = [], stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None, seed: Optional[int] = None,
response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat
] = None,
max_tokens: int = 256, max_tokens: int = 256,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
@ -37,7 +41,10 @@ class LlamaChatCompletionHandler(Protocol):
logits_processor: Optional[llama.LogitsProcessorList] = None, logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None, grammar: Optional[llama.LlamaGrammar] = None,
**kwargs, # type: ignore **kwargs, # type: ignore
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]: ) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
... ...
@ -169,6 +176,7 @@ class ChatFormatterResponse:
class ChatFormatter(Protocol): class ChatFormatter(Protocol):
def __call__( def __call__(
self, self,
*,
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any, **kwargs: Any,
) -> ChatFormatterResponse: ) -> ChatFormatterResponse:
@ -264,17 +272,24 @@ _CHAT_FORMATS: Dict[str, ChatFormatter] = {}
def register_chat_format(name: str): def register_chat_format(name: str):
def decorator(f: ChatFormatter): def decorator(f: ChatFormatter):
def basic_create_chat_completion( def basic_create_chat_completion(
*,
llama: llama.Llama, llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None, functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[ function_call: Optional[
Union[str, llama_types.ChatCompletionFunctionCall] llama_types.ChatCompletionRequestFunctionCall
] = None, ] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2, temperature: float = 0.2,
top_p: float = 0.95, top_p: float = 0.95,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
stop: Optional[Union[str, List[str]]] = [], stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat
] = None,
max_tokens: int = 256, max_tokens: int = 256,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
@ -286,8 +301,10 @@ def register_chat_format(name: str):
model: Optional[str] = None, model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None, logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None, grammar: Optional[llama.LlamaGrammar] = None,
**kwargs, # type: ignore
) -> Union[ ) -> Union[
llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk] llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]: ]:
result = f( result = f(
messages=messages, messages=messages,
@ -300,6 +317,10 @@ def register_chat_format(name: str):
rstop = result.stop if isinstance(result.stop, list) else [result.stop] rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop stop = stop + rstop
if response_format is not None and response_format["type"] == "json_object":
print("hello world")
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
completion_or_chunks = llama.create_completion( completion_or_chunks = llama.create_completion(
prompt=prompt, prompt=prompt,
temperature=temperature, temperature=temperature,
@ -307,6 +328,7 @@ def register_chat_format(name: str):
top_k=top_k, top_k=top_k,
stream=stream, stream=stream,
stop=stop, stop=stop,
seed=seed,
max_tokens=max_tokens, max_tokens=max_tokens,
presence_penalty=presence_penalty, presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty, frequency_penalty=frequency_penalty,
@ -319,7 +341,7 @@ def register_chat_format(name: str):
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
) )
return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore return _convert_completion_to_chat(completion_or_chunks, stream=stream)
register_chat_completion_handler(name)(basic_create_chat_completion) register_chat_completion_handler(name)(basic_create_chat_completion)
return f return f
@ -727,7 +749,7 @@ def functionary_chat_handler(
assert "usage" in completion assert "usage" in completion
assert isinstance(function_call, str) assert isinstance(function_call, str)
assert stream is False # TODO: support stream mode assert stream is False # TODO: support stream mode
return llama_types.CreateChatCompletionResponse( return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"], id="chat" + completion["id"],
@ -759,7 +781,9 @@ class Llava15ChatHandler:
self._llava_cpp = llava_cpp self._llava_cpp = llava_cpp
self.clip_model_path = clip_model_path self.clip_model_path = clip_model_path
self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0) self.clip_ctx = self._llava_cpp.clip_model_load(
self.clip_model_path.encode(), 0
)
def __del__(self): def __del__(self):
if self.clip_ctx is not None: if self.clip_ctx is not None:
@ -805,12 +829,21 @@ class Llava15ChatHandler:
logits_processor: Optional[llama.LogitsProcessorList] = None, logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None, grammar: Optional[llama.LlamaGrammar] = None,
**kwargs, # type: ignore **kwargs, # type: ignore
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]: ) -> Union[
assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
assert (
llama.context_params.logits_all is True
) # BUG: logits_all=True is required for llava
assert self.clip_ctx is not None assert self.clip_ctx is not None
system_prompt = _get_system_message(messages) system_prompt = _get_system_message(messages)
system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." system_prompt = (
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions." system_prompt
if system_prompt != ""
else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
)
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
user_role = "\nUSER:" user_role = "\nUSER:"
assistant_role = "\nASSISTANT:" assistant_role = "\nASSISTANT:"
llama.reset() llama.reset()
@ -818,51 +851,86 @@ class Llava15ChatHandler:
for message in messages: for message in messages:
if message["role"] == "user" and message["content"] is not None: if message["role"] == "user" and message["content"] is not None:
if isinstance(message["content"], str): if isinstance(message["content"], str):
llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False)) llama.eval(
llama.tokenize(
f"{user_role} {message['content']}".encode("utf8"),
add_bos=False,
)
)
else: else:
assert isinstance(message["content"], list) assert isinstance(message["content"], list)
llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)) llama.eval(
llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
)
for content in message["content"]: for content in message["content"]:
if content["type"] == "text": if content["type"] == "text":
llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False)) llama.eval(
llama.tokenize(
f"{content['text']}".encode("utf8"), add_bos=False
)
)
if content["type"] == "image_url": if content["type"] == "image_url":
image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"]) image_bytes = (
self.load_image(content["image_url"]["url"])
if isinstance(content["image_url"], dict)
else self.load_image(content["image_url"])
)
import array import array
data_array = array.array('B', image_bytes)
c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array) data_array = array.array("B", image_bytes)
embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes)) c_ubyte_ptr = (
ctypes.c_ubyte * len(data_array)
).from_buffer(data_array)
embed = self._llava_cpp.llava_image_embed_make_with_bytes(
ctx_clip=self.clip_ctx,
n_threads=llama.context_params.n_threads,
image_bytes=c_ubyte_ptr,
image_bytes_length=len(image_bytes),
)
# image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes) # image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
# embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes)) # embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
try: try:
n_past = ctypes.c_int(llama.n_tokens) n_past = ctypes.c_int(llama.n_tokens)
n_past_p = ctypes.pointer(n_past) n_past_p = ctypes.pointer(n_past)
self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p) self._llava_cpp.llava_eval_image_embed(
ctx_llama=llama.ctx,
embed=embed,
n_batch=llama.n_batch,
n_past=n_past_p,
)
assert llama.n_ctx() >= n_past.value assert llama.n_ctx() >= n_past.value
llama.n_tokens = n_past.value llama.n_tokens = n_past.value
finally: finally:
self._llava_cpp.llava_image_embed_free(embed) self._llava_cpp.llava_image_embed_free(embed)
if message["role"] == "assistant" and message["content"] is not None: if message["role"] == "assistant" and message["content"] is not None:
llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False)) llama.eval(
llama.tokenize(
f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
)
)
llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False)) llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
prompt = llama._input_ids.tolist() prompt = llama._input_ids.tolist()
return _convert_completion_to_chat(llama.create_completion( return _convert_completion_to_chat(
prompt=prompt, llama.create_completion(
temperature=temperature, prompt=prompt,
top_p=top_p, temperature=temperature,
top_k=top_k, top_p=top_p,
top_k=top_k,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
),
stream=stream, stream=stream,
stop=stop, )
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
), stream=stream)

View file

@ -152,6 +152,10 @@ class ChatCompletionFunctionCallOption(TypedDict):
name: str name: str
class ChatCompletionRequestResponseFormat(TypedDict):
type: Literal["text", "json_object"]
class ChatCompletionRequestMessageContentPartText(TypedDict): class ChatCompletionRequestMessageContentPartText(TypedDict):
type: Literal["text"] type: Literal["text"]
text: str text: str
@ -241,7 +245,7 @@ ChatCompletionRequestFunctionCall = Union[
Literal["none", "auto"], ChatCompletionRequestFunctionCallOption Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
] ]
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
class ChatCompletionToolFunction(TypedDict): class ChatCompletionToolFunction(TypedDict):

View file

@ -792,6 +792,9 @@ class CreateChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = frequency_penalty_field frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None) logit_bias: Optional[Dict[str, float]] = Field(None)
seed: Optional[int] = Field(None) seed: Optional[int] = Field(None)
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
default=None,
)
# ignored or currently unsupported # ignored or currently unsupported
model: Optional[str] = model_field model: Optional[str] = model_field