Add JSON mode support. Closes #881
This commit is contained in:
parent
4852a6a39c
commit
b30b9c338b
4 changed files with 116 additions and 39 deletions
|
@ -1901,6 +1901,7 @@ class Llama:
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
stop: Optional[Union[str, List[str]]] = [],
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
response_format: Optional[ChatCompletionRequestResponseFormat] = None,
|
||||||
max_tokens: int = 256,
|
max_tokens: int = 256,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
|
@ -1946,6 +1947,7 @@ class Llama:
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
response_format=response_format,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
|
|
|
@ -5,8 +5,9 @@ import ctypes
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
||||||
|
|
||||||
import llama_cpp.llama_types as llama_types
|
|
||||||
import llama_cpp.llama as llama
|
import llama_cpp.llama as llama
|
||||||
|
import llama_cpp.llama_types as llama_types
|
||||||
|
import llama_cpp.llama_grammar as llama_grammar
|
||||||
|
|
||||||
|
|
||||||
class LlamaChatCompletionHandler(Protocol):
|
class LlamaChatCompletionHandler(Protocol):
|
||||||
|
@ -25,6 +26,9 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
stop: Optional[Union[str, List[str]]] = [],
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
|
response_format: Optional[
|
||||||
|
llama_types.ChatCompletionRequestResponseFormat
|
||||||
|
] = None,
|
||||||
max_tokens: int = 256,
|
max_tokens: int = 256,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
|
@ -37,7 +41,10 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
**kwargs, # type: ignore
|
**kwargs, # type: ignore
|
||||||
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
|
) -> Union[
|
||||||
|
llama_types.CreateChatCompletionResponse,
|
||||||
|
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||||
|
]:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
@ -169,6 +176,7 @@ class ChatFormatterResponse:
|
||||||
class ChatFormatter(Protocol):
|
class ChatFormatter(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatFormatterResponse:
|
) -> ChatFormatterResponse:
|
||||||
|
@ -264,17 +272,24 @@ _CHAT_FORMATS: Dict[str, ChatFormatter] = {}
|
||||||
def register_chat_format(name: str):
|
def register_chat_format(name: str):
|
||||||
def decorator(f: ChatFormatter):
|
def decorator(f: ChatFormatter):
|
||||||
def basic_create_chat_completion(
|
def basic_create_chat_completion(
|
||||||
|
*,
|
||||||
llama: llama.Llama,
|
llama: llama.Llama,
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
function_call: Optional[
|
function_call: Optional[
|
||||||
Union[str, llama_types.ChatCompletionFunctionCall]
|
llama_types.ChatCompletionRequestFunctionCall
|
||||||
] = None,
|
] = None,
|
||||||
|
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||||
|
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
stop: Optional[Union[str, List[str]]] = [],
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
response_format: Optional[
|
||||||
|
llama_types.ChatCompletionRequestResponseFormat
|
||||||
|
] = None,
|
||||||
max_tokens: int = 256,
|
max_tokens: int = 256,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
|
@ -286,8 +301,10 @@ def register_chat_format(name: str):
|
||||||
model: Optional[str] = None,
|
model: Optional[str] = None,
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
|
**kwargs, # type: ignore
|
||||||
) -> Union[
|
) -> Union[
|
||||||
llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
|
llama_types.CreateChatCompletionResponse,
|
||||||
|
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||||
]:
|
]:
|
||||||
result = f(
|
result = f(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -300,6 +317,10 @@ def register_chat_format(name: str):
|
||||||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||||
stop = stop + rstop
|
stop = stop + rstop
|
||||||
|
|
||||||
|
if response_format is not None and response_format["type"] == "json_object":
|
||||||
|
print("hello world")
|
||||||
|
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||||
|
|
||||||
completion_or_chunks = llama.create_completion(
|
completion_or_chunks = llama.create_completion(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
@ -307,6 +328,7 @@ def register_chat_format(name: str):
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stop=stop,
|
stop=stop,
|
||||||
|
seed=seed,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
presence_penalty=presence_penalty,
|
presence_penalty=presence_penalty,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
|
@ -319,7 +341,7 @@ def register_chat_format(name: str):
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
|
return _convert_completion_to_chat(completion_or_chunks, stream=stream)
|
||||||
|
|
||||||
register_chat_completion_handler(name)(basic_create_chat_completion)
|
register_chat_completion_handler(name)(basic_create_chat_completion)
|
||||||
return f
|
return f
|
||||||
|
@ -727,7 +749,7 @@ def functionary_chat_handler(
|
||||||
|
|
||||||
assert "usage" in completion
|
assert "usage" in completion
|
||||||
assert isinstance(function_call, str)
|
assert isinstance(function_call, str)
|
||||||
assert stream is False # TODO: support stream mode
|
assert stream is False # TODO: support stream mode
|
||||||
|
|
||||||
return llama_types.CreateChatCompletionResponse(
|
return llama_types.CreateChatCompletionResponse(
|
||||||
id="chat" + completion["id"],
|
id="chat" + completion["id"],
|
||||||
|
@ -759,7 +781,9 @@ class Llava15ChatHandler:
|
||||||
self._llava_cpp = llava_cpp
|
self._llava_cpp = llava_cpp
|
||||||
self.clip_model_path = clip_model_path
|
self.clip_model_path = clip_model_path
|
||||||
|
|
||||||
self.clip_ctx = self._llava_cpp.clip_model_load(self.clip_model_path.encode(), 0)
|
self.clip_ctx = self._llava_cpp.clip_model_load(
|
||||||
|
self.clip_model_path.encode(), 0
|
||||||
|
)
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
if self.clip_ctx is not None:
|
if self.clip_ctx is not None:
|
||||||
|
@ -805,12 +829,21 @@ class Llava15ChatHandler:
|
||||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
grammar: Optional[llama.LlamaGrammar] = None,
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
**kwargs, # type: ignore
|
**kwargs, # type: ignore
|
||||||
) -> Union[llama_types.CreateChatCompletionResponse, Iterator[llama_types.CreateChatCompletionStreamResponse]]:
|
) -> Union[
|
||||||
assert llama.context_params.logits_all is True # BUG: logits_all=True is required for llava
|
llama_types.CreateChatCompletionResponse,
|
||||||
|
Iterator[llama_types.CreateChatCompletionStreamResponse],
|
||||||
|
]:
|
||||||
|
assert (
|
||||||
|
llama.context_params.logits_all is True
|
||||||
|
) # BUG: logits_all=True is required for llava
|
||||||
assert self.clip_ctx is not None
|
assert self.clip_ctx is not None
|
||||||
system_prompt = _get_system_message(messages)
|
system_prompt = _get_system_message(messages)
|
||||||
system_prompt = system_prompt if system_prompt != "" else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
system_prompt = (
|
||||||
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
system_prompt
|
||||||
|
if system_prompt != ""
|
||||||
|
else "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||||
|
)
|
||||||
|
system_prompt = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
|
||||||
user_role = "\nUSER:"
|
user_role = "\nUSER:"
|
||||||
assistant_role = "\nASSISTANT:"
|
assistant_role = "\nASSISTANT:"
|
||||||
llama.reset()
|
llama.reset()
|
||||||
|
@ -818,51 +851,86 @@ class Llava15ChatHandler:
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message["role"] == "user" and message["content"] is not None:
|
if message["role"] == "user" and message["content"] is not None:
|
||||||
if isinstance(message["content"], str):
|
if isinstance(message["content"], str):
|
||||||
llama.eval(llama.tokenize(f"{user_role} {message['content']}".encode("utf8"), add_bos=False))
|
llama.eval(
|
||||||
|
llama.tokenize(
|
||||||
|
f"{user_role} {message['content']}".encode("utf8"),
|
||||||
|
add_bos=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
assert isinstance(message["content"], list)
|
assert isinstance(message["content"], list)
|
||||||
llama.eval(llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False))
|
llama.eval(
|
||||||
|
llama.tokenize(f"{user_role} ".encode("utf8"), add_bos=False)
|
||||||
|
)
|
||||||
for content in message["content"]:
|
for content in message["content"]:
|
||||||
if content["type"] == "text":
|
if content["type"] == "text":
|
||||||
llama.eval(llama.tokenize(f"{content['text']}".encode("utf8"), add_bos=False))
|
llama.eval(
|
||||||
|
llama.tokenize(
|
||||||
|
f"{content['text']}".encode("utf8"), add_bos=False
|
||||||
|
)
|
||||||
|
)
|
||||||
if content["type"] == "image_url":
|
if content["type"] == "image_url":
|
||||||
image_bytes = self.load_image(content["image_url"]["url"]) if isinstance(content["image_url"], dict) else self.load_image(content["image_url"])
|
image_bytes = (
|
||||||
|
self.load_image(content["image_url"]["url"])
|
||||||
|
if isinstance(content["image_url"], dict)
|
||||||
|
else self.load_image(content["image_url"])
|
||||||
|
)
|
||||||
import array
|
import array
|
||||||
data_array = array.array('B', image_bytes)
|
|
||||||
c_ubyte_ptr = (ctypes.c_ubyte * len(data_array)).from_buffer(data_array)
|
data_array = array.array("B", image_bytes)
|
||||||
embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=llama.context_params.n_threads, image_bytes=c_ubyte_ptr, image_bytes_length=len(image_bytes))
|
c_ubyte_ptr = (
|
||||||
|
ctypes.c_ubyte * len(data_array)
|
||||||
|
).from_buffer(data_array)
|
||||||
|
embed = self._llava_cpp.llava_image_embed_make_with_bytes(
|
||||||
|
ctx_clip=self.clip_ctx,
|
||||||
|
n_threads=llama.context_params.n_threads,
|
||||||
|
image_bytes=c_ubyte_ptr,
|
||||||
|
image_bytes_length=len(image_bytes),
|
||||||
|
)
|
||||||
# image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
|
# image_bytes_p = (ctypes.c_uint8 * len(image_bytes)).from_buffer_copy(image_bytes)
|
||||||
# embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
|
# embed = self._llava_cpp.llava_image_embed_make_with_bytes(ctx_clip=self.clip_ctx, n_threads=1, image_bytes=image_bytes_p, image_bytes_length=len(image_bytes))
|
||||||
try:
|
try:
|
||||||
n_past = ctypes.c_int(llama.n_tokens)
|
n_past = ctypes.c_int(llama.n_tokens)
|
||||||
n_past_p = ctypes.pointer(n_past)
|
n_past_p = ctypes.pointer(n_past)
|
||||||
self._llava_cpp.llava_eval_image_embed(ctx_llama=llama.ctx, embed=embed, n_batch=llama.n_batch, n_past=n_past_p)
|
self._llava_cpp.llava_eval_image_embed(
|
||||||
|
ctx_llama=llama.ctx,
|
||||||
|
embed=embed,
|
||||||
|
n_batch=llama.n_batch,
|
||||||
|
n_past=n_past_p,
|
||||||
|
)
|
||||||
assert llama.n_ctx() >= n_past.value
|
assert llama.n_ctx() >= n_past.value
|
||||||
llama.n_tokens = n_past.value
|
llama.n_tokens = n_past.value
|
||||||
finally:
|
finally:
|
||||||
self._llava_cpp.llava_image_embed_free(embed)
|
self._llava_cpp.llava_image_embed_free(embed)
|
||||||
if message["role"] == "assistant" and message["content"] is not None:
|
if message["role"] == "assistant" and message["content"] is not None:
|
||||||
llama.eval(llama.tokenize(f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False))
|
llama.eval(
|
||||||
|
llama.tokenize(
|
||||||
|
f"ASSISTANT: {message['content']}".encode("utf8"), add_bos=False
|
||||||
|
)
|
||||||
|
)
|
||||||
llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
|
llama.eval(llama.tokenize(f"{assistant_role}".encode("utf8"), add_bos=False))
|
||||||
|
|
||||||
prompt = llama._input_ids.tolist()
|
prompt = llama._input_ids.tolist()
|
||||||
|
|
||||||
return _convert_completion_to_chat(llama.create_completion(
|
return _convert_completion_to_chat(
|
||||||
prompt=prompt,
|
llama.create_completion(
|
||||||
temperature=temperature,
|
prompt=prompt,
|
||||||
top_p=top_p,
|
temperature=temperature,
|
||||||
top_k=top_k,
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
stream=stream,
|
||||||
|
stop=stop,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
repeat_penalty=repeat_penalty,
|
||||||
|
tfs_z=tfs_z,
|
||||||
|
mirostat_mode=mirostat_mode,
|
||||||
|
mirostat_tau=mirostat_tau,
|
||||||
|
mirostat_eta=mirostat_eta,
|
||||||
|
model=model,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
grammar=grammar,
|
||||||
|
),
|
||||||
stream=stream,
|
stream=stream,
|
||||||
stop=stop,
|
)
|
||||||
max_tokens=max_tokens,
|
|
||||||
presence_penalty=presence_penalty,
|
|
||||||
frequency_penalty=frequency_penalty,
|
|
||||||
repeat_penalty=repeat_penalty,
|
|
||||||
tfs_z=tfs_z,
|
|
||||||
mirostat_mode=mirostat_mode,
|
|
||||||
mirostat_tau=mirostat_tau,
|
|
||||||
mirostat_eta=mirostat_eta,
|
|
||||||
model=model,
|
|
||||||
logits_processor=logits_processor,
|
|
||||||
grammar=grammar,
|
|
||||||
), stream=stream)
|
|
||||||
|
|
|
@ -152,6 +152,10 @@ class ChatCompletionFunctionCallOption(TypedDict):
|
||||||
name: str
|
name: str
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionRequestResponseFormat(TypedDict):
|
||||||
|
type: Literal["text", "json_object"]
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequestMessageContentPartText(TypedDict):
|
class ChatCompletionRequestMessageContentPartText(TypedDict):
|
||||||
type: Literal["text"]
|
type: Literal["text"]
|
||||||
text: str
|
text: str
|
||||||
|
@ -241,7 +245,7 @@ ChatCompletionRequestFunctionCall = Union[
|
||||||
Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
|
Literal["none", "auto"], ChatCompletionRequestFunctionCallOption
|
||||||
]
|
]
|
||||||
|
|
||||||
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
|
ChatCompletionFunctionParameters = Dict[str, JsonType] # TODO: make this more specific
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionToolFunction(TypedDict):
|
class ChatCompletionToolFunction(TypedDict):
|
||||||
|
|
|
@ -792,6 +792,9 @@ class CreateChatCompletionRequest(BaseModel):
|
||||||
frequency_penalty: Optional[float] = frequency_penalty_field
|
frequency_penalty: Optional[float] = frequency_penalty_field
|
||||||
logit_bias: Optional[Dict[str, float]] = Field(None)
|
logit_bias: Optional[Dict[str, float]] = Field(None)
|
||||||
seed: Optional[int] = Field(None)
|
seed: Optional[int] = Field(None)
|
||||||
|
response_format: Optional[llama_cpp.ChatCompletionRequestResponseFormat] = Field(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
# ignored or currently unsupported
|
# ignored or currently unsupported
|
||||||
model: Optional[str] = model_field
|
model: Optional[str] = model_field
|
||||||
|
|
Loading…
Reference in a new issue