Andrei Betlen 2024-02-08 23:34:45 -05:00
commit 5b4ad6c80b
4 changed files with 525 additions and 34 deletions

README.md

@@ -293,19 +293,16 @@ To constrain the response to a specific JSON Schema, you can use the `schema` pr
The high-level API also provides a simple interface for function calling.

The only set of models that supports full function calling at this time is [functionary](https://github.com/MeetKai/functionary). The various gguf-converted files for this set of models can be found [here](https://huggingface.co/meetkai). Functionary is able to intelligently call functions and also analyze any provided function outputs to generate coherent responses. All v2 models of functionary support **parallel function calling**. You can provide either `functionary-v1` or `functionary-v2` for the `chat_format` when initializing the Llama class.

Note that due to discrepancies between llama.cpp and HuggingFace's tokenizers, it is required to provide an HF tokenizer for functionary. The `LlamaHFTokenizer` class can be initialized and passed into the Llama class. This will override the default llama.cpp tokenizer used in the Llama class. The tokenizer files are already included in the respective HF repositories hosting the gguf files.
```python
>>> from llama_cpp import Llama, LlamaHFTokenizer
>>> tokenizer = LlamaHFTokenizer.from_pretrained("path/to/functionary/")
>>> llm = Llama(model_path="path/to/functionary/llama-model.gguf", tokenizer=tokenizer, chat_format="functionary-v2")
>>> llm.create_chat_completion(
      messages = [
        {
          "role": "user",
          "content": "Extract Jason is 25 years old"
@@ -332,12 +329,12 @@ The gguf-converted files for this model can be found here: [functionary-7b-v1](h
          }
        }
      }],
      tool_choice={
        "type": "function",
        "function": {
          "name": "UserDetail"
        }
      },
)
```
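
Functionary's responses come back in the OpenAI-style shape: each entry in `tool_calls` carries the function name and its arguments as a JSON string (the example above yields a single `UserDetail` call). As a rough sketch, not part of the library, the calls can be read back with a small helper like this:

```python
import json
from typing import Any, Dict, List, Tuple

def extract_tool_calls(response: Dict[str, Any]) -> List[Tuple[str, Dict[str, Any]]]:
    """Return (name, arguments) pairs from a create_chat_completion() response."""
    message = response["choices"][0]["message"]
    return [
        (call["function"]["name"], json.loads(call["function"]["arguments"]))
        for call in message.get("tool_calls") or []
    ]

# With the example above this would return something like:
#   [("UserDetail", {"name": "Jason", "age": 25})]
```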

llama_cpp/llama.py

@@ -2,6 +2,7 @@ from __future__ import annotations

import os
import sys
import abc
import uuid
import time
import multiprocessing

@@ -14,11 +15,14 @@ from typing import (
    Iterator,
    Deque,
    Callable,
    Any,
)
from collections import deque

import ctypes

from llama_cpp.llama_types import List

from .llama_types import *
from .llama_grammar import LlamaGrammar
from .llama_cache import (
@@ -95,6 +99,8 @@ class Llama:
        chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
        # Speculative Decoding
        draft_model: Optional[LlamaDraftModel] = None,
        # Tokenizer Override
        tokenizer: Optional[BaseLlamaTokenizer] = None,
        # Misc
        verbose: bool = True,
        # Extra Params
@@ -159,6 +165,7 @@ class Llama:
            chat_format: String specifying the chat format to use when calling create_chat_completion.
            chat_handler: Optional chat handler to use when calling create_chat_completion.
            draft_model: Optional draft model to use for speculative decoding.
            tokenizer: Optional tokenizer to override the default tokenizer from llama.cpp.
            verbose: Print verbose output to stderr.

        Raises:
@@ -235,6 +242,7 @@ class Llama:
        self.n_threads_batch = n_threads_batch or max(
            multiprocessing.cpu_count() // 2, 1
        )

        # Context Params
        self.context_params = llama_cpp.llama_context_default_params()
        self.context_params.seed = seed
@@ -286,6 +294,10 @@ class Llama:
        self._model = _LlamaModel(
            path_model=self.model_path, params=self.model_params, verbose=self.verbose
        )

        # Override tokenizer
        self.tokenizer_ = tokenizer or LlamaTokenizer(self)

        # Set the default value for the context and correct the batch
        if n_ctx == 0:
            n_ctx = self._model.n_ctx_train()
@@ -431,18 +443,19 @@ class Llama:
        Returns:
            A list of tokens.
        """
        return self.tokenizer_.tokenize(text, add_bos, special)

    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
        """Detokenize a list of tokens.

        Args:
            tokens: The list of tokens to detokenize.
            prev_tokens: The list of previous tokens. Offset mapping will be performed if provided

        Returns:
            The detokenized string.
        """
        return self.tokenizer_.detokenize(tokens, prev_tokens)

    def set_cache(self, cache: Optional[BaseLlamaCache]):
        """Set the cache.
@@ -935,7 +948,8 @@ class Llama:
            if stream:
                remaining_tokens = completion_tokens[returned_tokens:]
                prev_tokens = completion_tokens[:returned_tokens]
                remaining_text = self.detokenize(completion_tokens, prev_tokens)
                remaining_length = len(remaining_text)

                # We want to avoid yielding any characters from
@@ -957,13 +971,13 @@ class Llama:
                    for token in remaining_tokens:
                        if token == self.token_bos():
                            continue
                        token_end_position += len(remaining_text)
                        # Check if stop sequence is in the token
                        if token_end_position > (
                            remaining_length - first_stop_position
                        ):
                            break
                        token_str = remaining_text.decode(
                            "utf-8", errors="ignore"
                        )
                        text_offset = len(prompt) + len(
@@ -988,11 +1002,7 @@ class Llama:
                        }
                        top_logprob.update({token_str: current_logprobs[int(token)]})
                        logprobs_or_none = {
                            "tokens": [token_str],
                            "text_offset": [text_offset],
                            "token_logprobs": [current_logprobs[int(token)]],
                            "top_logprobs": [top_logprob],
@@ -1005,9 +1015,7 @@ class Llama:
                        "model": model_name,
                        "choices": [
                            {
                                "text": token_str,
                                "index": 0,
                                "logprobs": logprobs_or_none,
                                "finish_reason": None,
@@ -1019,7 +1027,7 @@ class Llama:
                    decode_success = False
                    for i in range(1, len(remaining_tokens) + 1):
                        try:
                            bs = remaining_text
                            ts = bs.decode("utf-8")
                            decode_success = True
                            break
@@ -1055,6 +1063,7 @@ class Llama:
            if len(completion_tokens) >= max_tokens:
                text = self.detokenize(completion_tokens)
                finish_reason = "length"
                break
@@ -1693,8 +1702,8 @@ class Llama:
        """Return the vocabulary size."""
        return self._model.n_vocab()

    def tokenizer(self) -> LlamaTokenizer:
        """Return the llama tokenizer for this model."""
        return LlamaTokenizer(self)

    def token_eos(self) -> int:
@@ -1738,23 +1747,71 @@ class Llama:
        return longest_prefix


class BaseLlamaTokenizer(abc.ABC):
    @abc.abstractmethod
    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
        raise NotImplementedError

    @abc.abstractmethod
    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
        raise NotImplementedError


class LlamaTokenizer(BaseLlamaTokenizer):
    def __init__(self, llama: Llama):
        self.llama = llama
        self._model = llama._model  # type: ignore

    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
        return self._model.tokenize(text, add_bos=add_bos, special=special)

    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
        if prev_tokens is not None:
            return self._model.detokenize(tokens[len(prev_tokens):])
        else:
            return self._model.detokenize(tokens)

    def encode(self, text: str, add_bos: bool = True, special: bool = True) -> List[int]:
        return self.tokenize(
            text.encode("utf-8", errors="ignore"), add_bos=add_bos, special=special
        )

    def decode(self, tokens: List[int]) -> str:
        return self.detokenize(tokens).decode("utf-8", errors="ignore")

    @classmethod
    def from_ggml_file(cls, path: str) -> "LlamaTokenizer":
        return cls(Llama(model_path=path, vocab_only=True))
class LlamaHFTokenizer(BaseLlamaTokenizer):
    def __init__(self, hf_tokenizer: Any):
        self.hf_tokenizer = hf_tokenizer

    def tokenize(self, text: bytes, add_bos: bool = True, special: bool = True) -> List[int]:
        return self.hf_tokenizer.encode(text.decode("utf-8", errors="ignore"), add_special_tokens=special)

    def detokenize(self, tokens: List[int], prev_tokens: Optional[List[int]] = None) -> bytes:
        if prev_tokens is not None:
            text = self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")
            prev_text = self.hf_tokenizer.decode(prev_tokens).encode("utf-8", errors="ignore")
            return text[len(prev_text):]
        else:
            return self.hf_tokenizer.decode(tokens).encode("utf-8", errors="ignore")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str) -> "LlamaHFTokenizer":
        try:
            from transformers import AutoTokenizer
        except ImportError:
            raise ImportError(
                "The `transformers` library is required to use the `HFTokenizer`."
                "You can install it with `pip install transformers`."
            )
        hf_tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path)
        return cls(hf_tokenizer)
class LlamaState:
    def __init__(
        self,
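
Taken together, the `tokenizer` override and the new `prev_tokens` argument let `Llama.detokenize` delegate to whichever `BaseLlamaTokenizer` is installed and recover only the text generated after an already-emitted prefix, which is what the streaming changes above rely on. A minimal sketch of that behaviour, assuming a local functionary gguf and its HF tokenizer files (the paths are placeholders):

```python
from llama_cpp import Llama, LlamaHFTokenizer

# Placeholder paths: point these at a functionary gguf and its matching HF tokenizer files.
tokenizer = LlamaHFTokenizer.from_pretrained("path/to/functionary/")
llm = Llama(
    model_path="path/to/functionary/llama-model.gguf",
    tokenizer=tokenizer,              # overrides the default llama.cpp tokenizer
    chat_format="functionary-v2",
)

tokens = llm.tokenize(b"Hello, world!")   # routed through LlamaHFTokenizer.tokenize
prefix = tokens[:2]

# With prev_tokens, detokenize returns (roughly) only the bytes after the prefix,
# which is how the streaming path avoids re-emitting text it already yielded.
tail = llm.detokenize(tokens, prev_tokens=prefix)
full = llm.detokenize(tokens)
assert full.endswith(tail)
```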

llama_cpp/llama_chat_format.py

@@ -4,7 +4,9 @@ import os
import json
import ctypes
import dataclasses
import random
import string
from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol

import jinja2

@@ -1332,6 +1334,435 @@ def functionary_chat_handler(
    )
@register_chat_completion_handler("functionary-v1")
@register_chat_completion_handler("functionary-v2")
def functionary_v1_v2_chat_handler(
llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
min_p: float = 0.05,
typical_p: float = 1.0,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
max_tokens: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
**kwargs, # type: ignore
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
tokenizer = llama.tokenizer_
assert hasattr(tokenizer, "hf_tokenizer"), "Please provide a valid hf_tokenizer_path from https://huggingface.co/meetkai when initializing the Llama class"
from transformers import AutoTokenizer
if "<|START_OF_FUNCTION_CALL|>" in tokenizer.hf_tokenizer.additional_special_tokens:
version = "v1"
END_SYSTEM_TOKEN = "<|END_OF_SYSTEM|>"
END_USER_TOKEN = "<|END_OF_USER|>"
END_ASSISTANT_TOKEN = "<|END_OF_ASSISTANT|>"
END_FUNCTION_RESULT_TOKEN = "<|END_OF_FUNCTION_RESULT|>"
START_FUNCTION_CALL_TOKEN = "<|START_OF_FUNCTION_CALL|>"
END_FUNCTION_CALL_TOKEN = "<|END_OF_FUNCTION_CALL|>"
else:
version = "v2"
RECIPIENT_TOKEN = "<|recipient|>"
FROM_TOKEN = "<|from|>"
STOP_TOKEN = "<|stop|>"
CONTENT_TOKEN = "<|content|>"
def generate_type_definition(
param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs
) -> str:
indent = " " * indent_level
if "$ref" in param:
# Reference to a shared definition
ref_name = param["$ref"].split("/")[
-1
] # Extract the type name from the reference
return ref_name
elif param.get("type") == "array":
items = param.get("items", {})
item_type = generate_type_definition(items, indent_level + 1, shared_defs)
return f"Array<{item_type}>"
elif param.get("type") == "object":
properties = param.get("properties", {})
nested_schema = "{\n"
for nested_param_name, nested_param in properties.items():
nested_param_type = generate_type_definition(
nested_param, indent_level + 1, shared_defs
)
nested_schema += (
f"{indent} {nested_param_name}: {nested_param_type},\n"
)
nested_schema += indent + "}"
return nested_schema
elif "enum" in param:
# Enum type
return " | ".join([f'"{enum_value}"' for enum_value in param["enum"]])
else:
# Simple type
return param.get("type", "any")
def generate_shared_definitions(shared_defs, indent_level: int) -> str:
indent = " " * indent_level
shared_definitions = ""
for def_name, def_properties in shared_defs.items():
shared_definitions += f"{indent}type {def_name} = "
if def_properties.get("type") == "object":
shared_definitions += generate_type_definition(
def_properties, indent_level, shared_defs
)
elif "enum" in def_properties:
# Enum type
shared_definitions += " | ".join(
[f'"{enum_value}"' for enum_value in def_properties["enum"]]
)
shared_definitions += ";\n"
return shared_definitions
def generate_schema_from_functions(functions, namespace="functions") -> str:
schema = (
"// Supported function definitions that should be called when necessary.\n"
)
schema += f"namespace {namespace} {{\n\n"
# Generate shared definitions
shared_definitions = {}
for function in functions:
parameters = function.get("parameters", {})
shared_definitions.update(parameters.get("$defs", {}))
schema += generate_shared_definitions(shared_definitions, 1)
for function in functions:
function_name = function["name"]
description = function.get("description", "")
parameters = function.get("parameters", {})
required_params = parameters.get("required", [])
schema += f"// {description}\n"
schema += f"type {function_name} = (_: {{\n"
for param_name, param in parameters.get("properties", {}).items():
param_description = param.get("description", "")
param_type = generate_type_definition(param, 2, shared_definitions)
optional_indicator = "" if param_name in required_params else "?"
schema += f"// {param_description}\n"
schema += f"{param_name}{optional_indicator}: {param_type},\n"
schema += "}) => any;\n\n"
schema += "}} // namespace {}".format(namespace)
return schema
def prepare_messages_for_inference(
messages: List[llama_types.ChatCompletionRequestMessage],
tokenizer: AutoTokenizer,
version: Literal["v1", "v2"],
functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
):
all_messages: List[llama_types.ChatCompletionRequestMessage] = []
if functions is not None:
all_messages.append(
llama_types.ChatCompletionRequestSystemMessage(
role="system", content=generate_schema_from_functions(functions)
)
)
elif tools is not None:
all_messages.append(
llama_types.ChatCompletionRequestSystemMessage(
role="system",
content=generate_schema_from_functions(
[
tool["function"]
for tool in tools
if tool["type"] == "function"
]
),
)
)
all_messages.append(
llama_types.ChatCompletionRequestSystemMessage(
role="system", content=SYSTEM_MESSAGE
)
)
for message in messages:
# Function call responses
if message["role"] == "function" and "name" in message:
message["name"] = f"functions.{message['name']}"
# Function call requests by assistant
if "function_call" in message:
message["function_call"][
"name"
] = f"functions.{message['function_call']['name']}"
all_messages.append(message)
if version == "v1":
suffix = "assistant:\n"
else:
suffix = "<|from|>assistant\n<|recipient|>"
return tokenizer.hf_tokenizer.apply_chat_template(all_messages, tokenize=False) + suffix
if tools is not None:
functions = [tool["function"] for tool in tools if tool["type"] == "function"]
if tool_choice is not None:
function_call = (
tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
)
prompt = prepare_messages_for_inference(messages, tokenizer, version, functions, tools)
# If no tools/functions are provided
if function_call is None and (functions is None or len(functions) == 0):
if version == "v1":
stop = END_ASSISTANT_TOKEN
else:
stop = STOP_TOKEN
prompt += "all\n<|content|>"
completion_or_completion_chunks = llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
)
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
assert stream is False # TODO: support stream mode
def get_grammar(function_call):
function_body = None
for function in functions or []:
if function["name"] == function_call:
function_body = function["parameters"]
break
for tool in tools or []:
if tool["type"] == "function" and tool["function"]["name"] == function_call:
function_body = tool["function"]["parameters"]
break
try:
with suppress_stdout_stderr(disable=llama.verbose):
grammar_text = llama_grammar.json_schema_to_gbnf(
json.dumps(function_body)
)
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.json_schema_to_gbnf(json.dumps(function_body))
)
print(grammar_text)
except Exception as e:
if llama.verbose:
print(
"Failed to parse function body as JSON schema, falling back to default grammar"
)
print(e)
with suppress_stdout_stderr(disable=llama.verbose):
grammar = llama_grammar.LlamaGrammar.from_string(
llama_grammar.JSON_GBNF
)
return grammar
def create_completion(stop):
completion: llama_types.Completion = llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
min_p=min_p,
typical_p=typical_p,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
)
return completion
function_calls, function_bodies = [], []
if version == "v1":
# If no or "auto" tool_choice/function_call
if function_call is None or (
isinstance(function_call, str) and function_call == "auto"
):
stops = ["\n", END_ASSISTANT_TOKEN]
# If tool_choice/function_call is "none"
elif isinstance(function_call, str) and function_call == "none":
prompt = prepare_messages_for_inference(messages, tokenizer, version, [], [])
stops = END_ASSISTANT_TOKEN
# If tool_choice/function_call is provided
elif isinstance(function_call, dict):
prompt += f"{START_FUNCTION_CALL_TOKEN}{function_call['name']}:\n"
stops = END_FUNCTION_CALL_TOKEN
function_call = function_call["name"]
function_calls.append(function_call)
grammar = get_grammar(function_call)
else:
prompt = prompt
stops = ["\n", END_ASSISTANT_TOKEN]
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]
# If the generation does not involve a function call
if START_FUNCTION_CALL_TOKEN not in prompt and START_FUNCTION_CALL_TOKEN not in completion_text:
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# If the generation involves a function call in completion, generate the parameters
elif START_FUNCTION_CALL_TOKEN not in prompt and START_FUNCTION_CALL_TOKEN in completion_text:
prompt += completion_text.replace(f"{START_FUNCTION_CALL_TOKEN} ", START_FUNCTION_CALL_TOKEN) + "\n"
function_calls.append(completion_text.split(START_FUNCTION_CALL_TOKEN)[-1][:-1].strip())
grammar = get_grammar(function_calls[-1])
completion = create_completion(stop=END_FUNCTION_CALL_TOKEN)
function_bodies.append(completion["choices"][0]["text"].strip())
# If the prompt involves a function call, just append generated parameters to function_bodies
else:
function_bodies.append(completion_text.strip())
else:
# Loop until all parallel function calls are generated
while True:
# If no or "auto" tool_choice/function_call
if function_call is None or (
isinstance(function_call, str) and function_call == "auto"
):
grammar = None
stops = CONTENT_TOKEN
# If tool_choice/function_call is "none"
elif isinstance(function_call, str) and function_call == "none":
prompt = prepare_messages_for_inference(messages, tokenizer, version, [], []) + "all\n<|content|>"
stops = STOP_TOKEN
# If tool_choice/function_call is provided
elif isinstance(function_call, dict):
prompt += f"{function_call['name']}\n{CONTENT_TOKEN}"
stops = STOP_TOKEN
function_call = function_call["name"]
function_calls.append(function_call)
grammar = get_grammar(function_call)
else:
prompt = prompt
stops = STOP_TOKEN
completion = create_completion(stop=stops)
completion_text = completion["choices"][0]["text"]
# If the generation does not involve a function call
if prompt.endswith("all\n<|content|>") and not completion_text.startswith("all"):
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# Generate model response if the model decides not to call any function
elif (prompt.endswith(RECIPIENT_TOKEN) and completion_text.startswith("all")):
prompt += completion_text + CONTENT_TOKEN
completion = create_completion(stop=STOP_TOKEN)
return _convert_completion_to_chat(completion, stream=stream) # type: ignore
# Generate parameters if model decides to call a function
elif prompt.endswith(RECIPIENT_TOKEN):
function_calls.append(completion_text[:-1])
grammar = get_grammar(function_calls[-1])
completion = create_completion(stop=[STOP_TOKEN, "\n"])
function_bodies.append(completion["choices"][0]["text"].strip())
prompt += f"{function_calls[-1]}\n{CONTENT_TOKEN}{function_bodies[-1]}"
grammar = None
# Try to generate the beginning of next turn
# If empty completion, break from loop
next_turn_completion_text = create_completion(
stop=[STOP_TOKEN, RECIPIENT_TOKEN]
)["choices"][0]["text"]
if len(next_turn_completion_text) > 0:
prompt += f"\n{FROM_TOKEN}assistant\n{RECIPIENT_TOKEN}"
else:
break
# Break from loop if tool_choice/function_call is provided as a dict
else:
function_bodies.append(completion_text.strip())
break
assert "usage" in completion
assert len(function_calls) > 0
assert len(function_calls) == len(function_bodies)
tool_calls = []
for function_call, function_body in zip(function_calls, function_bodies):
tool_calls.append(
{
"id": "call_" + "".join(
[random.choice(string.ascii_letters + string.digits) for _ in range(24)]
),
"type": "function",
"function": {
"name": function_call,
"arguments": function_body,
},
}
)
# TODO: support stream mode
return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"],
object="chat.completion",
created=completion["created"],
model=completion["model"],
choices=[
{
"index": 0,
"message": {
"role": "assistant",
"content": None,
"function_call": {
"name": tool_calls[0]["function"]["name"],
"arguments": tool_calls[0]["function"]["arguments"],
},
"tool_calls": tool_calls,
},
"finish_reason": "tool_calls",
}
],
usage=completion["usage"],
)
class Llava15ChatHandler:
    _clip_free = None
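
When the model commits to a function, the handler above constrains the argument generation with a GBNF grammar derived from that function's JSON schema (see `get_grammar`). A standalone sketch of that conversion step, reusing the `UserDetail` parameters from the README example (variable names are illustrative):

```python
import json

from llama_cpp import llama_grammar

# The UserDetail parameters from the README example above.
user_detail_schema = {
    "type": "object",
    "title": "UserDetail",
    "properties": {
        "name": {"title": "Name", "type": "string"},
        "age": {"title": "Age", "type": "integer"},
    },
    "required": ["name", "age"],
}

# json_schema_to_gbnf turns the JSON schema into a GBNF grammar string, and
# LlamaGrammar.from_string compiles it; passing the result as `grammar` to
# create_completion forces the generated arguments to be valid JSON for this schema.
gbnf = llama_grammar.json_schema_to_gbnf(json.dumps(user_detail_schema))
grammar = llama_grammar.LlamaGrammar.from_string(gbnf)
```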

llama_cpp/server/model.py

@@ -93,6 +93,10 @@ class LlamaProxy:
            )
        )

        tokenizer: Optional[llama_cpp.BaseLlamaTokenizer] = None
        if settings.hf_pretrained_model_name_or_path is not None:
            tokenizer = llama_cpp.LlamaHFTokenizer.from_pretrained(settings.hf_pretrained_model_name_or_path)

        draft_model = None
        if settings.draft_model is not None:
            draft_model = llama_speculative.LlamaPromptLookupDecoding(
@@ -156,6 +160,8 @@ class LlamaProxy:
            chat_handler=chat_handler,
            # Speculative Decoding
            draft_model=draft_model,
            # Tokenizer
            tokenizer=tokenizer,
            # Misc
            verbose=settings.verbose,
        )