feat: Integrate functionary v1.4 and v2 models + add custom tokenizer support to Llama class (#1078)
* convert functionary-v1 chat handler to use hf autotokenizer * add hf_tokenizer + inteegrate functionary-v1.4 prompt template * integrate functionary v2 prompt template * update readme * set up parallel function calling wip * set up parallel function calling * Update README.md * Update README.md * refactor tokenizers * include old functionary handler for backward compatibility * add hf_tokenizer_path in server ModelSettings * convert functionary-v1 chat handler to use hf autotokenizer * add hf_tokenizer + inteegrate functionary-v1.4 prompt template * integrate functionary v2 prompt template * update readme * set up parallel function calling wip * resolve merge conflict * Update README.md * Update README.md * refactor tokenizers * include old functionary handler for backward compatibility * add hf_tokenizer_path in server ModelSettings * Cleanup PR, fix breaking changes * Use hf_pretrained_model_name_or_path for tokenizer * fix hf tokenizer in streaming * update README * refactor offset mapping --------- Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
parent
34f31040f6
commit
901827013b
4 changed files with 525 additions and 34 deletions
19
README.md
19
README.md
|
@ -293,19 +293,16 @@ To constrain the response to a specific JSON Schema, you can use the `schema` pr
|
|||
|
||||
The high-level API also provides a simple interface for function calling.
|
||||
|
||||
Note that the only model that supports full function calling at this time is "functionary".
|
||||
The gguf-converted files for this model can be found here: [functionary-7b-v1](https://huggingface.co/abetlen/functionary-7b-v1-GGUF)
|
||||
The only set of models that supports full function calling at this time is [functionary](https://github.com/MeetKai/functionary). The various gguf-converted files for this set of models can be found [here](https://huggingface.co/meetkai). Functionary is able to intelligently call functions and also analyze any provided function outputs to generate coherent responses. All v2 models of functionary supports **parallel function calling**. You can provide either `functionary-v1` or `functionary-v2` for the `chat_format` when initializing the Llama class.
|
||||
|
||||
Note that due to discrepancies between llama.cpp and HuggingFace's tokenizers, it is required to provide HF Tokenizer for functionary. The `LlamaHFTokenizer` class can be initialized and passed into the Llama class. This will override the default llama.cpp tokenizer used in Llama class. The tokenizer files are already included in the respective HF repositories hosting the gguf files.
|
||||
|
||||
```python
|
||||
>>> from llama_cpp import Llama
|
||||
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", chat_format="functionary")
|
||||
>>> from llama_cpp import Llama, LlamaHFTokenizer
|
||||
>>> tokenizer = LlamaHFTokenizer.from_pretrained("path/to/functionary/")
|
||||
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", tokenizer=tokenizer, chat_format="functionary-v2")
|
||||
>>> llm.create_chat_completion(
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"
|
||||
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Extract Jason is 25 years old"
|
||||
|
@ -332,12 +329,12 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
|
|||
}
|
||||
}
|
||||
}],
|
||||
tool_choice=[{
|
||||
tool_choice={
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "UserDetail"
|
||||
}
|
||||
}]
|
||||
},
|
||||
)
|
||||
```
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import os
|
||||
import sys
|
||||
import abc
|
||||
import uuid
|
||||
import time
|
||||
import multiprocessing
|
||||
|
@ -14,11 +15,14 @@ from typing import (
|
|||
Iterator,
|
||||
Deque,
|
||||
Callable,
|
||||
Any,
|
||||
)
|
||||
from collections import deque
|
||||
|
||||
import ctypes
|
||||
|
||||
from llama_cpp.llama_types import List
|
||||
|
||||
from .llama_types import *
|
||||
from .llama_grammar import LlamaGrammar
|
||||
from .llama_cache import (
|
||||
|
@ -95,6 +99,8 @@ class Llama:
|
|||
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
|
||||
# Speculative Decoding
|
||||
draft_model: Optional[LlamaDraftModel] = None,
|
||||
# Tokenizer Override
|
||||
tokenizer: Optional[BaseLlamaTokenizer] = None,
|
||||
# Misc
|
||||
verbose: bool = True,
|
||||
# Extra Params
|
||||
|
@ -159,6 +165,7 @@ class Llama:
|
|||
chat_format: String specifying the chat format to use when calling create_chat_completion.
|
||||
chat_handler: Optional chat handler to use when calling create_chat_completion.
|
||||
draft_model: Optional draft model to use for speculative decoding.
|
||||
tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
|
||||
verbose: Print verbose output to stderr.
|
||||
|
||||
Raises:
|
||||
|
@ -235,6 +242,7 @@ class Llama:
|
|||
self.n_threads_batch = n_threads_batch or max(
|
||||
multiprocessing.cpu_count() // 2, 1
|
||||
)
|
||||
|
||||
# Context Params
|
||||
self.context_params = llama_cpp.llama_context_default_params()
|
||||
self.context_params.seed = seed
|
||||
|
@ -286,6 +294,10 @@ class Llama:
|
|||
self._model = _LlamaModel(
|
||||
path_model=self.model_path, params=self.model_params, verbose=self.verbose
|
||||
)
|
||||
|
||||
# Override tokenizer
|
||||
self.tokenizer_ = tokenizer or LlamaTokenizer(self)
|
||||
|
||||
# Set the default value for the context and correct the batch
|
||||
if n_ctx == 0:
|
||||
n_ctx = self._model.n_ctx_train()
|
||||
|
@ -431,18 +443,19 @@ class Llama:
|
|||
Returns:
|
||||
A list of tokens.
|
||||
"""
|
||||
return self._model.tokenize(text, add_bos, special)
|
||||
return self.tokenizer_.tokenize(text, add_bos, special)
|
||||
|
||||
def detokenize(self, tokens: List[int]) -> bytes:
|
||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
||||
"""Detokenize a list of tokens.
|
||||
|
||||
Args:
|
||||
tokens: The list of tokens to detokenize.
|
||||
prev_tokens: The list of previous tokens. Offset mapping will be performed if provided
|
||||
|
||||
Returns:
|
||||
The detokenized string.
|
||||
"""
|
||||
return self._model.detokenize(tokens)
|
||||
return self.tokenizer_.detokenize(tokens, prev_tokens)
|
||||
|
||||
def set_cache(self, cache: Optional[BaseLlamaCache]):
|
||||
"""Set the cache.
|
||||
|
@ -935,7 +948,8 @@ class Llama:
|
|||
|
||||
if stream:
|
||||
remaining_tokens = completion_tokens[returned_tokens:]
|
||||
remaining_text = self.detokenize(remaining_tokens)
|
||||
prev_tokens = completion_tokens[:returned_tokens]
|
||||
remaining_text = self.detokenize(completion_tokens, prev_tokens)
|
||||
remaining_length = len(remaining_text)
|
||||
|
||||
# We want to avoid yielding any characters from
|
||||
|
@ -957,13 +971,13 @@ class Llama:
|
|||
for token in remaining_tokens:
|
||||
if token == self.token_bos():
|
||||
continue
|
||||
token_end_position += len(self.detokenize([token]))
|
||||
token_end_position += len(remaining_text)
|
||||
# Check if stop sequence is in the token
|
||||
if token_end_position > (
|
||||
remaining_length - first_stop_position
|
||||
):
|
||||
break
|
||||
token_str = self.detokenize([token]).decode(
|
||||
token_str = remaining_text.decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
text_offset = len(prompt) + len(
|
||||
|
@ -988,11 +1002,7 @@ class Llama:
|
|||
}
|
||||
top_logprob.update({token_str: current_logprobs[int(token)]})
|
||||
logprobs_or_none = {
|
||||
"tokens": [
|
||||
self.detokenize([token]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
],
|
||||
"tokens": [token_str],
|
||||
"text_offset": [text_offset],
|
||||
"token_logprobs": [current_logprobs[int(token)]],
|
||||
"top_logprobs": [top_logprob],
|
||||
|
@ -1005,9 +1015,7 @@ class Llama:
|
|||
"model": model_name,
|
||||
"choices": [
|
||||
{
|
||||
"text": self.detokenize([token]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
),
|
||||
"text": token_str,
|
||||
"index": 0,
|
||||
"logprobs": logprobs_or_none,
|
||||
"finish_reason": None,
|
||||
|
@ -1019,7 +1027,7 @@ class Llama:
|
|||
decode_success = False
|
||||
for i in range(1, len(remaining_tokens) + 1):
|
||||
try:
|
||||
bs = self.detokenize(remaining_tokens[:i])
|
||||
bs = remaining_text
|
||||
ts = bs.decode("utf-8")
|
||||
decode_success = True
|
||||
break
|
||||
|
@ -1055,6 +1063,7 @@ class Llama:
|
|||
|
||||
if len(completion_tokens) >= max_tokens:
|
||||
text = self.detokenize(completion_tokens)
|
||||
|
||||
finish_reason = "length"
|
||||
break
|
||||
|
||||
|
@ -1693,8 +1702,8 @@ class Llama:
|
|||
"""Return the vocabulary size."""
|
||||
return self._model.n_vocab()
|
||||
|
||||
def tokenizer(self) -> "LlamaTokenizer":
|
||||
"""Return the tokenizer for this model."""
|
||||
def tokenizer(self) -> LlamaTokenizer:
|
||||
"""Return the llama tokenizer for this model."""
|
||||
return LlamaTokenizer(self)
|
||||
|
||||
def token_eos(self) -> int:
|
||||
|
@ -1738,23 +1747,71 @@ class Llama:
|
|||
return longest_prefix
|
||||
|
||||
|
||||
class LlamaTokenizer:
|
||||
class BaseLlamaTokenizer(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class LlamaTokenizer(BaseLlamaTokenizer):
|
||||
def __init__(self, llama: Llama):
|
||||
self.llama = llama
|
||||
self._model = llama._model # type: ignore
|
||||
|
||||
def encode(self, text: str, add_bos: bool = True) -> List[int]:
|
||||
return self.llama.tokenize(
|
||||
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=True
|
||||
def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
|
||||
return self._model.tokenize(text, add_bos=add_bos, special=special)
|
||||
|
||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
||||
if prev_tokens is not None:
|
||||
return self._model.detokenize(tokens[len(prev_tokens):])
|
||||
else:
|
||||
return self._model.detokenize(tokens)
|
||||
|
||||
def encode(self, text: str, add_bos: bool = True, special: bool = True) -> List[int]:
|
||||
return self.tokenize(
|
||||
text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
|
||||
)
|
||||
|
||||
def decode(self, tokens: List[int]) -> str:
|
||||
return self.llama.detokenize(tokens).decode("utf-8", errors="ignore")
|
||||
return self.detokenize(tokens).decode("utf-8", errors="ignore")
|
||||
|
||||
@classmethod
|
||||
def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
|
||||
return cls(Llama(model_path=path, vocab_only=True))
|
||||
|
||||
|
||||
class LlamaHFTokenizer(BaseLlamaTokenizer):
|
||||
def __init__(self, hf_tokenizer: Any):
|
||||
self.hf_tokenizer = hf_tokenizer
|
||||
|
||||
def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
|
||||
return self.hf_tokenizer.encode(text.decode("utf-8", errors="ignore"), add_special_tokens=special)
|
||||
|
||||
def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
|
||||
if prev_tokens is not None:
|
||||
text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
|
||||
prev_text = self.hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
|
||||
return text[len(prev_text):]
|
||||
else:
|
||||
return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
|
||||
try:
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The `transformers` library is required to use the `HFTokenizer`."
|
||||
"You can install it with `pip install transformers`."
|
||||
)
|
||||
hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
|
||||
return cls(hf_tokenizer)
|
||||
|
||||
|
||||
class LlamaState:
|
||||
def __init__(
|
||||
self,
|
||||
|
|
|
@ -4,7 +4,9 @@ import os
|
|||
import json
|
||||
import ctypes
|
||||
import dataclasses
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
||||
import random
|
||||
import string
|
||||
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
|
||||
|
||||
import jinja2
|
||||
|
||||
|
@ -1332,6 +1334,435 @@ def functionary_chat_handler(
|
|||
)
|
||||
|
||||
|
||||
@register_chat_completion_handler("functionary-v1")
|
||||
@register_chat_completion_handler("functionary-v2")
|
||||
def functionary_v1_v2_chat_handler(
|
||||
llama: llama.Llama,
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
||||
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
||||
temperature: float = 0.2,
|
||||
top_p: float = 0.95,
|
||||
top_k: int = 40,
|
||||
min_p: float = 0.05,
|
||||
typical_p: float = 1.0,
|
||||
stream: bool = False,
|
||||
stop: Optional[Union[str, List[str]]] = [],
|
||||
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: float = 0.0,
|
||||
frequency_penalty: float = 0.0,
|
||||
repeat_penalty: float = 1.1,
|
||||
tfs_z: float = 1.0,
|
||||
mirostat_mode: int = 0,
|
||||
mirostat_tau: float = 5.0,
|
||||
mirostat_eta: float = 0.1,
|
||||
model: Optional[str] = None,
|
||||
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||
grammar: Optional[llama.LlamaGrammar] = None,
|
||||
**kwargs, # type: ignore
|
||||
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
|
||||
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
|
||||
|
||||
tokenizer = llama.tokenizer_
|
||||
assert hasattr(tokenizer, "hf_tokenizer"), "Please provide a valid hf_tokenizer_path from https://huggingface.co/meetkai when initializing the Llama class"
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
if "<|START_OF_FUNCTION_CALL|>" in tokenizer.hf_tokenizer.additional_special_tokens:
|
||||
version = "v1"
|
||||
END_SYSTEM_TOKEN = "<|END_OF_SYSTEM|>"
|
||||
END_USER_TOKEN = "<|END_OF_USER|>"
|
||||
END_ASSISTANT_TOKEN = "<|END_OF_ASSISTANT|>"
|
||||
END_FUNCTION_RESULT_TOKEN = "<|END_OF_FUNCTION_RESULT|>"
|
||||
START_FUNCTION_CALL_TOKEN = "<|START_OF_FUNCTION_CALL|>"
|
||||
END_FUNCTION_CALL_TOKEN = "<|END_OF_FUNCTION_CALL|>"
|
||||
else:
|
||||
version = "v2"
|
||||
RECIPIENT_TOKEN = "<|recipient|>"
|
||||
FROM_TOKEN = "<|from|>"
|
||||
STOP_TOKEN = "<|stop|>"
|
||||
CONTENT_TOKEN = "<|content|>"
|
||||
|
||||
def generate_type_definition(
|
||||
param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
|
||||
) -> str:
|
||||
indent = " " * indent_level
|
||||
if "$ref" in param:
|
||||
# Reference to a shared definition
|
||||
ref_name = param["$ref"].split("/")[
|
||||
-1
|
||||
] # Extract the type name from the reference
|
||||
return ref_name
|
||||
elif param.get("type") == "array":
|
||||
items = param.get("items", {})
|
||||
item_type = generate_type_definition(items, indent_level + 1, shared_defs)
|
||||
return f"Array<{item_type}>"
|
||||
elif param.get("type") == "object":
|
||||
properties = param.get("properties", {})
|
||||
nested_schema = "{\n"
|
||||
for nested_param_name, nested_param in properties.items():
|
||||
nested_param_type = generate_type_definition(
|
||||
nested_param, indent_level + 1, shared_defs
|
||||
)
|
||||
nested_schema += (
|
||||
f"{indent} {nested_param_name}: {nested_param_type},\n"
|
||||
)
|
||||
nested_schema += indent + "}"
|
||||
return nested_schema
|
||||
elif "enum" in param:
|
||||
# Enum type
|
||||
return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]])
|
||||
else:
|
||||
# Simple type
|
||||
return param.get("type", "any")
|
||||
|
||||
def generate_shared_definitions(shared_defs, indent_level: int) -> str:
|
||||
indent = " " * indent_level
|
||||
shared_definitions = ""
|
||||
for def_name, def_properties in shared_defs.items():
|
||||
shared_definitions += f"{indent}type {def_name} = "
|
||||
if def_properties.get("type") == "object":
|
||||
shared_definitions += generate_type_definition(
|
||||
def_properties, indent_level, shared_defs
|
||||
)
|
||||
elif "enum" in def_properties:
|
||||
# Enum type
|
||||
shared_definitions += " | ".join(
|
||||
[f'"{enum_value}"' for enum_value in def_properties["enum"]]
|
||||
)
|
||||
shared_definitions += ";\n"
|
||||
return shared_definitions
|
||||
|
||||
def generate_schema_from_functions(functions, namespace="functions") -> str:
|
||||
schema = (
|
||||
"// Supported function definitions that should be called when necessary.\n"
|
||||
)
|
||||
schema += f"namespace {namespace} {{\n\n"
|
||||
|
||||
# Generate shared definitions
|
||||
shared_definitions = {}
|
||||
for function in functions:
|
||||
parameters = function.get("parameters", {})
|
||||
shared_definitions.update(parameters.get("$defs", {}))
|
||||
|
||||
schema += generate_shared_definitions(shared_definitions, 1)
|
||||
|
||||
for function in functions:
|
||||
function_name = function["name"]
|
||||
description = function.get("description", "")
|
||||
parameters = function.get("parameters", {})
|
||||
required_params = parameters.get("required", [])
|
||||
|
||||
schema += f"// {description}\n"
|
||||
schema += f"type {function_name} = (_: {{\n"
|
||||
|
||||
for param_name, param in parameters.get("properties", {}).items():
|
||||
param_description = param.get("description", "")
|
||||
param_type = generate_type_definition(param, 2, shared_definitions)
|
||||
optional_indicator = "" if param_name in required_params else "?"
|
||||
schema += f"// {param_description}\n"
|
||||
schema += f"{param_name}{optional_indicator}: {param_type},\n"
|
||||
schema += "}) => any;\n\n"
|
||||
|
||||
schema += "}} // namespace {}".format(namespace)
|
||||
return schema
|
||||
|
||||
def prepare_messages_for_inference(
|
||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||
tokenizer: AutoTokenizer,
|
||||
version: Literal["v1", "v2"],
|
||||
functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
|
||||
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||
):
|
||||
all_messages: List[llama_types.ChatCompletionRequestMessage] = []
|
||||
if functions is not None:
|
||||
all_messages.append(
|
||||
llama_types.ChatCompletionRequestSystemMessage(
|
||||
role="system", content=generate_schema_from_functions(functions)
|
||||
)
|
||||
)
|
||||
elif tools is not None:
|
||||
all_messages.append(
|
||||
llama_types.ChatCompletionRequestSystemMessage(
|
||||
role="system",
|
||||
content=generate_schema_from_functions(
|
||||
[
|
||||
tool["function"]
|
||||
for tool in tools
|
||||
if tool["type"] == "function"
|
||||
]
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
all_messages.append(
|
||||
llama_types.ChatCompletionRequestSystemMessage(
|
||||
role="system", content=SYSTEM_MESSAGE
|
||||
)
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
# Function call responses
|
||||
if message["role"] == "function" and "name" in message:
|
||||
message["name"] = f"functions.{message['name']}"
|
||||
# Function call requests by assistant
|
||||
if "function_call" in message:
|
||||
message["function_call"][
|
||||
"name"
|
||||
] = f"functions.{message['function_call']['name']}"
|
||||
all_messages.append(message)
|
||||
|
||||
if version == "v1":
|
||||
suffix = "assistant:\n"
|
||||
else:
|
||||
suffix = "<|from|>assistant\n<|recipient|>"
|
||||
|
||||
return tokenizer.hf_tokenizer.apply_chat_template(all_messages, tokenize=False) + suffix
|
||||
|
||||
if tools is not None:
|
||||
functions = [tool["function"] for tool in tools if tool["type"] == "function"]
|
||||
|
||||
if tool_choice is not None:
|
||||
function_call = (
|
||||
tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
|
||||
)
|
||||
|
||||
prompt = prepare_messages_for_inference(messages, tokenizer, version, functions, tools)
|
||||
|
||||
# If no tools/functions are provided
|
||||
if function_call is None and (functions is None or len(functions) == 0):
|
||||
if version == "v1":
|
||||
stop = END_ASSISTANT_TOKEN
|
||||
else:
|
||||
stop = STOP_TOKEN
|
||||
prompt += "all\n<|content|>"
|
||||
|
||||
completion_or_completion_chunks = llama.create_completion(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
)
|
||||
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
|
||||
|
||||
assert stream is False # TODO: support stream mode
|
||||
|
||||
def get_grammar(function_call):
|
||||
function_body = None
|
||||
for function in functions or []:
|
||||
if function["name"] == function_call:
|
||||
function_body = function["parameters"]
|
||||
break
|
||||
for tool in tools or []:
|
||||
if tool["type"] == "function" and tool["function"]["name"] == function_call:
|
||||
function_body = tool["function"]["parameters"]
|
||||
break
|
||||
|
||||
try:
|
||||
with suppress_stdout_stderr(disable=llama.verbose):
|
||||
grammar_text = llama_grammar.json_schema_to_gbnf(
|
||||
json.dumps(function_body)
|
||||
)
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
||||
llama_grammar.json_schema_to_gbnf(json.dumps(function_body))
|
||||
)
|
||||
print(grammar_text)
|
||||
except Exception as e:
|
||||
if llama.verbose:
|
||||
print(
|
||||
"Failed to parse function body as JSON schema, falling back to default grammar"
|
||||
)
|
||||
print(e)
|
||||
with suppress_stdout_stderr(disable=llama.verbose):
|
||||
grammar = llama_grammar.LlamaGrammar.from_string(
|
||||
llama_grammar.JSON_GBNF
|
||||
)
|
||||
|
||||
return grammar
|
||||
|
||||
def create_completion(stop):
|
||||
completion: llama_types.Completion = llama.create_completion(
|
||||
prompt=prompt,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
stream=stream,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
repeat_penalty=repeat_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
model=model,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
)
|
||||
|
||||
return completion
|
||||
|
||||
function_calls, function_bodies = [], []
|
||||
|
||||
if version == "v1":
|
||||
# If no or "auto" tool_choice/function_call
|
||||
if function_call is None or (
|
||||
isinstance(function_call, str) and function_call == "auto"
|
||||
):
|
||||
stops = ["\n", END_ASSISTANT_TOKEN]
|
||||
# If tool_choice/function_call is "none"
|
||||
elif isinstance(function_call, str) and function_call == "none":
|
||||
prompt = prepare_messages_for_inference(messages, tokenizer, version, [], [])
|
||||
stops = END_ASSISTANT_TOKEN
|
||||
# If tool_choice/function_call is provided
|
||||
elif isinstance(function_call, dict):
|
||||
prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
|
||||
stops = END_FUNCTION_CALL_TOKEN
|
||||
function_call = function_call["name"]
|
||||
function_calls.append(function_call)
|
||||
grammar = get_grammar(function_call)
|
||||
else:
|
||||
prompt = prompt
|
||||
stops = ["\n", END_ASSISTANT_TOKEN]
|
||||
|
||||
completion = create_completion(stop=stops)
|
||||
completion_text = completion["choices"][0]["text"]
|
||||
|
||||
# If the generation does not involve a function call
|
||||
if START_FUNCTION_CALL_TOKEN not in prompt and START_FUNCTION_CALL_TOKEN not in completion_text:
|
||||
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
|
||||
# If the generation involves a function call in completion, generate the parameters
|
||||
elif START_FUNCTION_CALL_TOKEN not in prompt and START_FUNCTION_CALL_TOKEN in completion_text:
|
||||
prompt += completion_text.replace(f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN) + "\n"
|
||||
function_calls.append(completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip())
|
||||
grammar = get_grammar(function_calls[-1])
|
||||
completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
|
||||
function_bodies.append(completion["choices"][0]["text"].strip())
|
||||
# If the prompt involves a function call, just append generated parameters to function_bodies
|
||||
else:
|
||||
function_bodies.append(completion_text.strip())
|
||||
else:
|
||||
# Loop until all parallel function calls are generated
|
||||
while True:
|
||||
# If no or "auto" tool_choice/function_call
|
||||
if function_call is None or (
|
||||
isinstance(function_call, str) and function_call == "auto"
|
||||
):
|
||||
grammar = None
|
||||
stops = CONTENT_TOKEN
|
||||
# If tool_choice/function_call is "none"
|
||||
elif isinstance(function_call, str) and function_call == "none":
|
||||
prompt = prepare_messages_for_inference(messages, tokenizer, version, [], []) + "all\n<|content|>"
|
||||
stops = STOP_TOKEN
|
||||
# If tool_choice/function_call is provided
|
||||
elif isinstance(function_call, dict):
|
||||
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
|
||||
stops = STOP_TOKEN
|
||||
function_call = function_call["name"]
|
||||
function_calls.append(function_call)
|
||||
grammar = get_grammar(function_call)
|
||||
else:
|
||||
prompt = prompt
|
||||
stops = STOP_TOKEN
|
||||
|
||||
completion = create_completion(stop=stops)
|
||||
completion_text = completion["choices"][0]["text"]
|
||||
|
||||
# If the generation does not involve a function call
|
||||
if prompt.endswith("all\n<|content|>") and not completion_text.startswith("all"):
|
||||
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
|
||||
# Generate model response if the model decides not to call any function
|
||||
elif (prompt.endswith(RECIPIENT_TOKEN) and completion_text.startswith("all")):
|
||||
prompt += completion_text + CONTENT_TOKEN
|
||||
completion = create_completion(stop=STOP_TOKEN)
|
||||
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
|
||||
# Generate parameters if model decides to call a function
|
||||
elif prompt.endswith(RECIPIENT_TOKEN):
|
||||
function_calls.append(completion_text[:-1])
|
||||
grammar = get_grammar(function_calls[-1])
|
||||
completion = create_completion(stop=[STOP_TOKEN, "\n"])
|
||||
function_bodies.append(completion["choices"][0]["text"].strip())
|
||||
prompt += f"{function_calls[-1]}\n{CONTENT_TOKEN}{function_bodies[-1]}"
|
||||
grammar = None
|
||||
|
||||
# Try to generate the beginning of next turn
|
||||
# If empty completion, break from loop
|
||||
next_turn_completion_text = create_completion(
|
||||
stop=[STOP_TOKEN, RECIPIENT_TOKEN]
|
||||
)["choices"][0]["text"]
|
||||
if len(next_turn_completion_text) > 0:
|
||||
prompt += f"\n{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}"
|
||||
else:
|
||||
break
|
||||
# Break from loop if tool_choice/function_call is provided as a dict
|
||||
else:
|
||||
function_bodies.append(completion_text.strip())
|
||||
break
|
||||
|
||||
assert "usage" in completion
|
||||
assert len(function_calls) > 0
|
||||
assert len(function_calls) == len(function_bodies)
|
||||
|
||||
tool_calls = []
|
||||
for function_call, function_body in zip(function_calls, function_bodies):
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": "call_" + "".join(
|
||||
[random.choice(string.ascii_letters + string.digits) for _ in range(24)]
|
||||
),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_call,
|
||||
"arguments": function_body,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# TODO: support stream mode
|
||||
return llama_types.CreateChatCompletionResponse(
|
||||
id="chat" + completion["id"],
|
||||
object="chat.completion",
|
||||
created=completion["created"],
|
||||
model=completion["model"],
|
||||
choices=[
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": None,
|
||||
"function_call": {
|
||||
"name": tool_calls[0]["function"]["name"],
|
||||
"arguments": tool_calls[0]["function"]["arguments"],
|
||||
},
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
"finish_reason": "tool_calls",
|
||||
}
|
||||
],
|
||||
usage=completion["usage"],
|
||||
)
|
||||
|
||||
|
||||
class Llava15ChatHandler:
|
||||
_clip_free = None
|
||||
|
||||
|
|
|
@ -93,6 +93,10 @@ class LlamaProxy:
|
|||
)
|
||||
)
|
||||
|
||||
tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
|
||||
if settings.hf_pretrained_model_name_or_path is not None:
|
||||
tokenizer = llama_cpp.LlamaHFTokenizer.from_pretrained(settings.hf_pretrained_model_name_or_path)
|
||||
|
||||
draft_model = None
|
||||
if settings.draft_model is not None:
|
||||
draft_model = llama_speculative.LlamaPromptLookupDecoding(
|
||||
|
@ -156,6 +160,8 @@ class LlamaProxy:
|
|||
chat_handler=chat_handler,
|
||||
# Speculative Decoding
|
||||
draft_model=draft_model,
|
||||
# Tokenizer
|
||||
tokenizer=tokenizer,
|
||||
# Misc
|
||||
verbose=settings.verbose,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue