Add JSON mode support. Closes #881
This commit is contained in:
parent
4852a6a39c
commit
b30b9c338b
4 changed files with 116 additions and 39 deletions
|
@ -1901,6 +1901,7 @@ class Llama:
|
|||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
response_format: Optional[ChatCompletionRequestResponseFormat] = None,
|
||||
max_tokens: int = 256,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
|
@ -1946,6 +1947,7 @@ class Llama:
|
|||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
response_format=response_format,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
|
|
@ -5,8 +5,9 @@ import ctypes
|
|||
import dataclasses
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
||||
|
||||
import llama_cpp.llama_types as llama_types
|
||||
import llama_cpp.llama as llama
|
||||
import llama_cpp.llama_types as llama_types
|
||||
import llama_cpp.llama_grammar as llama_grammar
|
||||
|
||||
|
||||
class LlamaChatCompletionHandler(Protocol):
|
||||
|
@ -25,6 +26,9 @@ class LlamaChatCompletionHandler(Protocol):
|
|||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
response_format: Optional[
|
||||
llama_types.ChatCompletionRequestResponseFormat
|
||||
] = None,
|
||||
max_tokens: int = 256,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
|
@ -37,7 +41,10 @@ class LlamaChatCompletionHandler(Protocol):
|
|||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
|
||||
) -> Union[
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||
]:
|
||||
...
|
||||
|
||||
|
||||
|
@ -169,6 +176,7 @@ class ChatFormatterResponse:
|
|||
class ChatFormatter(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
*,
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
**kwargs: Any,
|
||||
) -> ChatFormatterResponse:
|
||||
|
@ -264,17 +272,24 @@ _CHAT_FORMATS: Dict[str, ChatFormatter] = {}
|
|||
def register_chat_format(name: str):
|
||||
def decorator(f: ChatFormatter):
|
||||
def basic_create_chat_completion(
|
||||
*,
|
||||
llama: llama.Llama,
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||
function_call: Optional[
|
||||
Union[str, llama_types.ChatCompletionFunctionCall]
|
||||
llama_types.ChatCompletionRequestFunctionCall
|
||||
] = None,
|
||||
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
||||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
seed: Optional[int] = None,
|
||||
response_format: Optional[
|
||||
llama_types.ChatCompletionRequestResponseFormat
|
||||
] = None,
|
||||
max_tokens: int = 256,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
|
@ -286,8 +301,10 @@ def register_chat_format(name: str):
|
|||
model: Optional[str] = None,
|
||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[
|
||||
llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||
]:
|
||||
result = f(
|
||||
messages=messages,
|
||||
|
@ -300,6 +317,10 @@ def register_chat_format(name: str):
|
|||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||
stop = stop + rstop
|
||||
|
||||
if response_format is not None and response_format["type"] == "json_object":
|
||||
print("hello world")
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||
|
||||
completion_or_chunks = llama.create_completion(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
|
@ -307,6 +328,7 @@ def register_chat_format(name: str):
|
|||
top_k=top_k,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
seed=seed,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
|
@ -319,7 +341,7 @@ def register_chat_format(name: str):
|
|||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
)
|
||||
return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
|
||||
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
|
||||
|
||||
register_chat_completion_handler(name)(basic_create_chat_completion)
|
||||
return f
|
||||
|
@ -727,7 +749,7 @@ def functionary_chat_handler(
|
|||
|
||||
assert "usage" in completion
|
||||
assert isinstance(function_call, str)
|
||||
assert stream is False # TODO: support stream mode
|
||||
assert stream is False # TODO: support stream mode
|
||||
|
||||
return llama_types.CreateChatCompletionResponse(
|
||||
id="chat" + completion["id"],
|
||||
|
@ -759,7 +781,9 @@ class Llava15ChatHandler:
|
|||
self._llava_cpp = llava_cpp
|
||||
self.clip_model_path = clip_model_path
|
||||
|
||||
self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0)
|
||||
self.clip_ctx = self._llava_cpp.clip_model_load(
|
||||
self.clip_model_path.encode(), 0
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
if self.clip_ctx is not None:
|
||||
|
@ -805,12 +829,21 @@ class Llava15ChatHandler:
|
|||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
|
||||
assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava
|
||||
) -> Union[
|
||||
llama_types.CreateChatCompletionResponse,
|
||||
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||
]:
|
||||
assert (
|
||||
llama.context_params.logits_all is True
|
||||
) # BUG: logits_all=True is required for llava
|
||||
assert self.clip_ctx is not None
|
||||
system_prompt = _get_system_message(messages)
|
||||
system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
system_prompt = (
|
||||
system_prompt
|
||||
if system_prompt != ""
|
||||
else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
)
|
||||
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||
user_role = "\nUSER:"
|
||||
assistant_role = "\nASSISTANT:"
|
||||
llama.reset()
|
||||
|
@ -818,51 +851,86 @@ class Llava15ChatHandler:
|
|||
for message in messages:
|
||||
if message["role"] == "user" and message["content"] is not None:
|
||||
if isinstance(message["content"], str):
|
||||
llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False))
|
||||
llama.eval(
|
||||
llama.tokenize(
|
||||
f"{user_role} {message['content']}".encode("utf8"),
|
||||
add_bos=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert isinstance(message["content"], list)
|
||||
llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False))
|
||||
llama.eval(
|
||||
llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
|
||||
)
|
||||
for content in message["content"]:
|
||||
if content["type"] == "text":
|
||||
llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False))
|
||||
llama.eval(
|
||||
llama.tokenize(
|
||||
f"{content['text']}".encode("utf8"), add_bos=False
|
||||
)
|
||||
)
|
||||
if content["type"] == "image_url":
|
||||
image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"])
|
||||
image_bytes = (
|
||||
self.load_image(content["image_url"]["url"])
|
||||
if isinstance(content["image_url"], dict)
|
||||
else self.load_image(content["image_url"])
|
||||
)
|
||||
import array
|
||||
data_array = array.array('B', image_bytes)
|
||||
c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)
|
||||
embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes))
|
||||
|
||||
data_array = array.array("B", image_bytes)
|
||||
c_ubyte_ptr = (
|
||||
ctypes.c_ubyte * len(data_array)
|
||||
).from_buffer(data_array)
|
||||
embed = self._llava_cpp.llava_image_embed_make_with_bytes(
|
||||
ctx_clip=self.clip_ctx,
|
||||
n_threads=llama.context_params.n_threads,
|
||||
image_bytes=c_ubyte_ptr,
|
||||
image_bytes_length=len(image_bytes),
|
||||
)
|
||||
# image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
|
||||
# embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
|
||||
try:
|
||||
n_past = ctypes.c_int(llama.n_tokens)
|
||||
n_past_p = ctypes.pointer(n_past)
|
||||
self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p)
|
||||
self._llava_cpp.llava_eval_image_embed(
|
||||
ctx_llama=llama.ctx,
|
||||
embed=embed,
|
||||
n_batch=llama.n_batch,
|
||||
n_past=n_past_p,
|
||||
)
|
||||
assert llama.n_ctx() >= n_past.value
|
||||
llama.n_tokens = n_past.value
|
||||
finally:
|
||||
self._llava_cpp.llava_image_embed_free(embed)
|
||||
if message["role"] == "assistant" and message["content"] is not None:
|
||||
llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False))
|
||||
llama.eval(
|
||||
llama.tokenize(
|
||||
f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
|
||||
)
|
||||
)
|
||||
llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
|
||||
|
||||
prompt = llama._input_ids.tolist()
|
||||
|
||||
return _convert_completion_to_chat(llama.create_completion(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
return _convert_completion_to_chat(
|
||||
llama.create_completion(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
),
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
), stream=stream)
|
||||
)
|
||||
|
|
|
@ -152,6 +152,10 @@ class ChatCompletionFunctionCallOption(TypedDict):
|
|||
name: str
|
||||
|
||||
|
||||
# Typed shape of the `response_format` parameter accepted by the chat-completion
# entry points (it is the annotation used for their `response_format` argument).
# `type` may only be "text" or "json_object". When it is "json_object", the
# handler generated by register_chat_format replaces the sampling grammar with
# llama_grammar.JSON_GBNF; no special handling of "text" is visible here.
class ChatCompletionRequestResponseFormat(TypedDict):
|
||||
type: Literal["text", "json_object"]
|
||||
|
||||
|
||||
class ChatCompletionRequestMessageContentPartText(TypedDict):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
@ -241,7 +245,7 @@ ChatCompletionRequestFunctionCall = Union[
|
|||
Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
|
||||
]
|
||||
|
||||
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
|
||||
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
|
||||
|
||||
|
||||
class ChatCompletionToolFunction(TypedDict):
|
||||
|
|
|
@ -792,6 +792,9 @@ class CreateChatCompletionRequest(BaseModel):
|
|||
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||
seed: Optional[int] = Field(None)
|
||||
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
|
||||
default=None,
|
||||
)
|
||||
|
||||
# ignored or currently unsupported
|
||||
model: Optional[str] = model_field
|
||||
|
|
Loading…
Reference in a new issue