Add JSON mode support. Closes #881

Andrei Betlen 2023-11-08 00:07:16 -05:00
parent 4852a6a39c
commit b30b9c338b
4 changed files with 116 additions and 39 deletions
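
For context, the change threads a new response_format parameter through Llama.create_chat_completion and the registered chat format handlers; passing {"type": "json_object"} switches generation over to a GBNF grammar that only admits valid JSON. A minimal usage sketch, assuming a local GGUF model (the model path and chat_format value below are placeholders, not part of this commit):

    from llama_cpp import Llama

    # Placeholder model path and chat format; substitute your own.
    llm = Llama(model_path="./models/model.gguf", chat_format="chatml")

    result = llm.create_chat_completion(
        messages=[
            {"role": "system", "content": "You return structured data as JSON."},
            {"role": "user", "content": "Name three planets."},
        ],
        response_format={"type": "json_object"},  # new parameter added by this commit
    )
    print(result["choices"][0]["message"]["content"])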

llama_cpp/llama.py

@@ -1901,6 +1901,7 @@ class Llama:
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
response_format: Optional[ChatCompletionRequestResponseFormat] = None,
max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
@@ -1946,6 +1947,7 @@ class Llama:
stream=stream,
stop=stop,
seed=seed,
response_format=response_format,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,

llama_cpp/llama_chat_format.py

@@ -5,8 +5,9 @@ import ctypes
import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
import llama_cpp.llama_types as llama_types
import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar
class LlamaChatCompletionHandler(Protocol):
@@ -25,6 +26,9 @@ class LlamaChatCompletionHandler(Protocol):
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat
] = None,
max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
@@ -37,7 +41,10 @@ class LlamaChatCompletionHandler(Protocol):
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
**kwargs, # type: ignore
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
...
@@ -169,6 +176,7 @@ class ChatFormatterResponse:
class ChatFormatter(Protocol):
def __call__(
self,
*,
messages: List[llama_types.ChatCompletionRequestMessage],
**kwargs: Any,
) -> ChatFormatterResponse:
@@ -264,17 +272,24 @@ _CHAT_FORMATS: Dict[str, ChatFormatter] = {}
def register_chat_format(name: str):
def decorator(f: ChatFormatter):
def basic_create_chat_completion(
*,
llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[
Union[str, llama_types.ChatCompletionFunctionCall]
llama_types.ChatCompletionRequestFunctionCall
] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
seed: Optional[int] = None,
response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat
] = None,
max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
@@ -286,8 +301,10 @@ def register_chat_format(name: str):
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
**kwargs, # type: ignore
) -> Union[
llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
result = f(
messages=messages,
@@ -299,6 +316,10 @@ def register_chat_format(name: str):
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop
if response_format is not None and response_format["type"] == "json_object":
print("hello world")
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
completion_or_chunks = llama.create_completion(
prompt=prompt,
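
This hunk is the core of JSON mode: when response_format is {"type": "json_object"}, the handler builds a LlamaGrammar from the bundled JSON_GBNF grammar and hands it to the create_completion call shown above, so sampling is constrained to valid JSON. The same mechanism sketched standalone (the identifiers come from the diff; using it outside the handler like this is an assumption):

    import llama_cpp.llama_grammar as llama_grammar

    # Same grammar object the handler constructs when JSON mode is requested.
    json_grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)

    # Assumed standalone use: create_completion accepts grammar= to constrain
    # sampling, e.g.
    # llm.create_completion(prompt="Reply in JSON.", grammar=json_grammar, max_tokens=128)
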
@@ -307,6 +328,7 @@ def register_chat_format(name: str):
top_k=top_k,
stream=stream,
stop=stop,
seed=seed,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
@@ -319,7 +341,7 @@ def register_chat_format(name: str):
logits_processor=logits_processor,
grammar=grammar,
)
return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
register_chat_completion_handler(name)(basic_create_chat_completion)
return f
@@ -727,7 +749,7 @@ def functionary_chat_handler(
assert "usage" in completion
assert isinstance(function_call, str)
assert stream is False # TODO: support stream mode
assert stream is False # TODO: support stream mode
return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"],
@@ -759,7 +781,9 @@ class Llava15ChatHandler:
self._llava_cpp = llava_cpp
self.clip_model_path = clip_model_path
self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0)
self.clip_ctx = self._llava_cpp.clip_model_load(
self.clip_model_path.encode(), 0
)
def __del__(self):
if self.clip_ctx is not None:
@@ -805,12 +829,21 @@ class Llava15ChatHandler:
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
**kwargs, # type: ignore
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava
) -> Union[
llama_types.CreateChatCompletionResponse,
Iterator[llama_types.CreateChatCompletionStreamResponse],
]:
assert (
llama.context_params.logits_all is True
) # BUG: logits_all=True is required for llava
assert self.clip_ctx is not None
system_prompt = _get_system_message(messages)
system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
system_prompt = (
system_prompt
if system_prompt != ""
else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
)
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
user_role = "\nUSER:"
assistant_role = "\nASSISTANT:"
llama.reset()
@@ -818,51 +851,86 @@ class Llava15ChatHandler:
for message in messages:
if message["role"] == "user" and message["content"] is not None:
if isinstance(message["content"], str):
llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False))
llama.eval(
llama.tokenize(
f"{user_role} {message['content']}".encode("utf8"),
add_bos=False,
)
)
else:
assert isinstance(message["content"], list)
llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False))
llama.eval(
llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
)
for content in message["content"]:
if content["type"] == "text":
llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False))
llama.eval(
llama.tokenize(
f"{content['text']}".encode("utf8"), add_bos=False
)
)
if content["type"] == "image_url":
image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"])
image_bytes = (
self.load_image(content["image_url"]["url"])
if isinstance(content["image_url"], dict)
else self.load_image(content["image_url"])
)
import array
data_array = array.array('B', image_bytes)
c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)
embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes))
data_array = array.array("B", image_bytes)
c_ubyte_ptr = (
ctypes.c_ubyte * len(data_array)
).from_buffer(data_array)
embed = self._llava_cpp.llava_image_embed_make_with_bytes(
ctx_clip=self.clip_ctx,
n_threads=llama.context_params.n_threads,
image_bytes=c_ubyte_ptr,
image_bytes_length=len(image_bytes),
)
# image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
# embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
try:
n_past = ctypes.c_int(llama.n_tokens)
n_past_p = ctypes.pointer(n_past)
self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p)
self._llava_cpp.llava_eval_image_embed(
ctx_llama=llama.ctx,
embed=embed,
n_batch=llama.n_batch,
n_past=n_past_p,
)
assert llama.n_ctx() >= n_past.value
llama.n_tokens = n_past.value
finally:
self._llava_cpp.llava_image_embed_free(embed)
if message["role"] == "assistant" and message["content"] is not None:
llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False))
llama.eval(
llama.tokenize(
f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
)
)
llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
prompt = llama._input_ids.tolist()
return _convert_completion_to_chat(llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
return _convert_completion_to_chat(
llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
),
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
), stream=stream)
)

llama_cpp/llama_types.py

@@ -152,6 +152,10 @@ class ChatCompletionFunctionCallOption(TypedDict):
name: str
class ChatCompletionRequestResponseFormat(TypedDict):
type: Literal["text", "json_object"]
class ChatCompletionRequestMessageContentPartText(TypedDict):
type: Literal["text"]
text: str
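
The new ChatCompletionRequestResponseFormat type matches the shape of OpenAI's response_format field: a dict with a single type key that is either "text" or "json_object". A value satisfying the new TypedDict, as a sketch:

    from llama_cpp.llama_types import ChatCompletionRequestResponseFormat

    fmt: ChatCompletionRequestResponseFormat = {"type": "json_object"}
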
@@ -241,7 +245,7 @@ ChatCompletionRequestFunctionCall = Union[
Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
]
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
class ChatCompletionToolFunction(TypedDict):

llama_cpp/server/app.py

@@ -792,6 +792,9 @@ class CreateChatCompletionRequest(BaseModel):
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)
seed: Optional[int] = Field(None)
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
default=None,
)
# ignored or currently unsupported
model: Optional[str] = model_field
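
With the new field on CreateChatCompletionRequest, the same option can be sent to the bundled server in the request body. A sketch of such a request, assuming the server is running locally via python -m llama_cpp.server and exposes its usual OpenAI-compatible /v1/chat/completions route (host, port, and route are assumptions, not shown in this diff):

    import requests  # third-party HTTP client, assumed available

    resp = requests.post(
        "http://localhost:8000/v1/chat/completions",
        json={
            "messages": [
                {"role": "user", "content": "Return a JSON object with a 'city' key."}
            ],
            "response_format": {"type": "json_object"},  # field added by this commit
            "max_tokens": 128,
        },
    )
    print(resp.json()["choices"][0]["message"]["content"])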