Add functionary support (#784)
* Add common grammars and json-schema-to-grammar utility function from llama.cpp
* Pass functions to format function
* Add basic functionary formatting
* Add LlamaChatHandler for more complex chat use cases
* Add function calling example notebook
* Add support for regular chat completions alongside function calling
This commit is contained in:
parent df31303a12
commit 3af7b21ff1
5 changed files with 936 additions and 99 deletions
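A quick sketch of how the new function-calling path can be driven through the Llama API once a functionary-style GGUF model is available; this is not part of the diff below, and the model path and filename are placeholder assumptions:

from llama_cpp import Llama

# Placeholder model path; assumes a functionary-style GGUF model on disk.
llama = Llama(model_path="./models/functionary.gguf", chat_format="functionary")

response = llama.create_chat_completion(
    messages=[{"role": "user", "content": "What's the weather like in Boston?"}],
    functions=[
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        }
    ],
    function_call="auto",
)
print(response["choices"][0]["message"])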
examples/notebooks/Functions.ipynb (normal file, 225 additions)
@@ -0,0 +1,225 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"id\": \"chatcmpl-a6db1bbb-a128-4c28-88fe-30717ec806b2\",\n",
      "  \"object\": \"chat.completion\",\n",
      "  \"created\": 1698989577,\n",
      "  \"model\": \"gpt-3.5-turbo-0613\",\n",
      "  \"choices\": [\n",
      "    {\n",
      "      \"index\": 0,\n",
      "      \"message\": {\n",
      "        \"role\": \"assistant\",\n",
      "        \"content\": \"The current weather in Boston is sunny with a temperature of 72 degrees\"\n",
      "      },\n",
      "      \"finish_reason\": \"length\"\n",
      "    }\n",
      "  ],\n",
      "  \"usage\": {\n",
      "    \"prompt_tokens\": 135,\n",
      "    \"completion_tokens\": 16,\n",
      "    \"total_tokens\": 151\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import openai\n",
    "import json\n",
    "\n",
    "openai.api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n",
    "openai.api_base = \"http://100.64.159.73:8000/v1\"\n",
    "\n",
    "# Example dummy function hard coded to return the same weather\n",
    "# In production, this could be your backend API or an external API\n",
    "def get_current_weather(location, unit=\"fahrenheit\"):\n",
    "    \"\"\"Get the current weather in a given location\"\"\"\n",
    "    weather_info = {\n",
    "        \"location\": location,\n",
    "        \"temperature\": \"72\",\n",
    "        \"unit\": unit,\n",
    "        \"forecast\": [\"sunny\", \"windy\"],\n",
    "    }\n",
    "    return json.dumps(weather_info)\n",
    "\n",
    "def run_conversation():\n",
    "    # Step 1: send the conversation and available functions to GPT\n",
    "    messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}]\n",
    "    functions = [\n",
    "        {\n",
    "            \"name\": \"get_current_weather\",\n",
    "            \"description\": \"Get the current weather in a given location\",\n",
    "            \"parameters\": {\n",
    "                \"type\": \"object\",\n",
    "                \"properties\": {\n",
    "                    \"location\": {\n",
    "                        \"type\": \"string\",\n",
    "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
    "                    },\n",
    "                    \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
    "                },\n",
    "                \"required\": [\"location\"],\n",
    "            },\n",
    "        }\n",
    "    ]\n",
    "    response = openai.ChatCompletion.create(\n",
    "        model=\"gpt-3.5-turbo-0613\",\n",
    "        messages=messages,\n",
    "        functions=functions,\n",
    "        function_call=\"auto\", # auto is default, but we'll be explicit\n",
    "    )\n",
    "    response_message = response[\"choices\"][0][\"message\"]\n",
    "\n",
    "    # Step 2: check if GPT wanted to call a function\n",
    "    if response_message.get(\"function_call\"):\n",
    "        # Step 3: call the function\n",
    "        # Note: the JSON response may not always be valid; be sure to handle errors\n",
    "        available_functions = {\n",
    "            \"get_current_weather\": get_current_weather,\n",
    "        } # only one function in this example, but you can have multiple\n",
    "        function_name = response_message[\"function_call\"][\"name\"]\n",
    "        function_to_call = available_functions[function_name]\n",
    "        function_args = json.loads(response_message[\"function_call\"][\"arguments\"])\n",
    "        function_response = function_to_call(\n",
    "            location=function_args.get(\"location\"),\n",
    "            unit=function_args.get(\"unit\"),\n",
    "        )\n",
    "\n",
    "        # Step 4: send the info on the function call and function response to GPT\n",
    "        messages.append(response_message) # extend conversation with assistant's reply\n",
    "        messages.append(\n",
    "            {\n",
    "                \"role\": \"function\",\n",
    "                \"name\": function_name,\n",
    "                \"content\": function_response,\n",
    "            }\n",
    "        ) # extend conversation with function response\n",
    "        second_response = openai.ChatCompletion.create(\n",
    "            model=\"gpt-3.5-turbo-0613\",\n",
    "            messages=messages,\n",
    "        ) # get a new response from GPT where it can see the function response\n",
    "        return second_response\n",
    "    else:\n",
    "        print(response)\n",
    "        print(\"No function\")\n",
    "\n",
    "print(run_conversation())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name='Jason' age=25\n"
     ]
    }
   ],
   "source": [
    "from pydantic import BaseModel\n",
    "from instructor import patch\n",
    "\n",
    "patch()\n",
    "\n",
    "class UserDetail(BaseModel):\n",
    "    name: str\n",
    "    age: int\n",
    "\n",
    "user: UserDetail = openai.ChatCompletion.create(\n",
    "    model=\"gpt-3.5-turbo\",\n",
    "    response_model=UserDetail,\n",
    "    messages=[\n",
    "        {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n",
    "    ]\n",
    ")\n",
    "print(user)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"id\": \"chatcmpl-59bcefad-9df5-4d6b-802c-5537b3e9044e\",\n",
      "  \"object\": \"chat.completion\",\n",
      "  \"created\": 1698989585,\n",
      "  \"model\": \"gpt-3.5-turbo-0613\",\n",
      "  \"choices\": [\n",
      "    {\n",
      "      \"index\": 0,\n",
      "      \"message\": {\n",
      "        \"role\": \"assistant\",\n",
      "        \"content\": \"I don't have up-to-date information on the current weather conditions\"\n",
      "      },\n",
      "      \"finish_reason\": \"length\"\n",
      "    }\n",
      "  ],\n",
      "  \"usage\": {\n",
      "    \"prompt_tokens\": 62,\n",
      "    \"completion_tokens\": 16,\n",
      "    \"total_tokens\": 78\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "response = openai.ChatCompletion.create(\n",
    "    model=\"gpt-3.5-turbo-0613\",\n",
    "    messages=[\n",
    "        {\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}\n",
    "    ]\n",
    ")\n",
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "python-3.8.10",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5+"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
llama_cpp/llama.py
@@ -24,7 +24,7 @@ import ctypes
from . import llama_cpp
from .llama_types import *
from .llama_grammar import LlamaGrammar
from . import llama_chat_format
import llama_cpp.llama_chat_format as llama_chat_format

import numpy as np
import numpy.typing as npt

@@ -1539,78 +1539,6 @@ class Llama:
            grammar=grammar,
        )

    def _convert_text_completion_to_chat(
        self, completion: Completion
    ) -> ChatCompletion:
        return {
            "id": "chat" + completion["id"],
            "object": "chat.completion",
            "created": completion["created"],
            "model": completion["model"],
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": completion["choices"][0]["text"],
                    },
                    "finish_reason": completion["choices"][0]["finish_reason"],
                }
            ],
            "usage": completion["usage"],
        }

    def _convert_text_completion_chunks_to_chat(
        self,
        chunks: Iterator[CompletionChunk],
    ) -> Iterator[ChatCompletionChunk]:
        for i, chunk in enumerate(chunks):
            if i == 0:
                yield {
                    "id": "chat" + chunk["id"],
                    "model": chunk["model"],
                    "created": chunk["created"],
                    "object": "chat.completion.chunk",
                    "choices": [
                        {
                            "index": 0,
                            "delta": {
                                "role": "assistant",
                            },
                            "finish_reason": None,
                        }
                    ],
                }
            yield {
                "id": "chat" + chunk["id"],
                "model": chunk["model"],
                "created": chunk["created"],
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "content": chunk["choices"][0]["text"],
                        }
                        if chunk["choices"][0]["finish_reason"] is None
                        else {},
                        "finish_reason": chunk["choices"][0]["finish_reason"],
                    }
                ],
            }

    def _convert_completion_to_chat(
        self,
        completion_or_chunks: Union[Completion, Iterator[CompletionChunk]],
        stream: bool = False,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        if stream:
            chunks: Iterator[CompletionChunk] = completion_or_chunks  # type: ignore
            return self._convert_text_completion_chunks_to_chat(chunks)
        else:
            completion: Completion = completion_or_chunks  # type: ignore
            return self._convert_text_completion_to_chat(completion)

    def create_chat_completion(
        self,
        messages: List[ChatCompletionRequestMessage],

@@ -1648,19 +1576,12 @@ class Llama:
        Returns:
            Generated chat completion or a stream of chat completion chunks.
        """
        format = llama_chat_format.get_chat_format(self.chat_format)
        result = format(
        handler = llama_chat_format.get_chat_completion_handler(self.chat_format)
        return handler(
            self,
            messages=messages,
        )
        prompt = result.prompt
        if result.stop is not None:
            stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
            rstop = result.stop if isinstance(result.stop, list) else [result.stop]
            stop = stop + rstop

        completion_or_chunks = self.create_completion(
            prompt=prompt,
            functions=functions,
            function_call=function_call,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,

@@ -1678,7 +1599,6 @@ class Llama:
            logits_processor=logits_processor,
            grammar=grammar,
        )
        return self._convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore

    def _free_model(self, *, _lbatch_free=llama_cpp._lib.llama_batch_free, _lfree_model=llama_cpp._lib.llama_free_model, _free=llama_cpp._lib.llama_free):
        batch = getattr(self, 'batch', None)
llama_cpp/llama_chat_format.py
@@ -1,6 +1,53 @@
from __future__ import annotations

import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union, Protocol
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
from . import llama_types
from . import llama


class LlamaChatCompletionHandler(Protocol):
    def __call__(
        self,
        llama: llama.Llama,
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
        function_call: Optional[
            Union[str, llama_types.ChatCompletionFunctionCall]
        ] = None,
        temperature: float = 0.2,
        top_p: float = 0.95,
        top_k: int = 40,
        stream: bool = False,
        stop: Optional[Union[str, List[str]]] = [],
        max_tokens: int = 256,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        tfs_z: float = 1.0,
        mirostat_mode: int = 0,
        mirostat_tau: float = 5.0,
        mirostat_eta: float = 0.1,
        model: Optional[str] = None,
        logits_processor: Optional[llama.LogitsProcessorList] = None,
        grammar: Optional[llama.LlamaGrammar] = None,
    ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
        ...


CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}


def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
    return CHAT_HANDLERS[name]


def register_chat_completion_handler(name: str):
    def decorator(f: LlamaChatCompletionHandler):
        CHAT_HANDLERS[name] = f
        return f

    return decorator


def _get_system_message(
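The registry above makes chat completion behaviour pluggable by name. A minimal sketch of a user-defined handler follows; the name "my-chat-format" and the prompt layout are illustrative assumptions, and it reuses the module-level _convert_completion_to_chat helper that appears later in this diff:

import llama_cpp.llama_chat_format as llama_chat_format

@llama_chat_format.register_chat_completion_handler("my-chat-format")
def my_chat_handler(llama, messages, **kwargs):
    # Flatten the conversation into a plain prompt and let the model continue it.
    prompt = "".join(
        f"{m['role']}: {m.get('content') or ''}\n" for m in messages
    ) + "assistant: "
    completion = llama.create_completion(
        prompt=prompt,
        stop=["user:"],
        max_tokens=kwargs.get("max_tokens", 256),
        temperature=kwargs.get("temperature", 0.2),
    )
    return llama_chat_format._convert_completion_to_chat(completion, stream=False)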
@@ -119,12 +166,150 @@ class ChatFormatter(Protocol):
        ...


class BasicChatHandler:
    def __init__(self, chat_format: str):
        self.chat_format = chat_format


def _convert_text_completion_to_chat(
    completion: llama_types.Completion,
) -> llama_types.ChatCompletion:
    return {
        "id": "chat" + completion["id"],
        "object": "chat.completion",
        "created": completion["created"],
        "model": completion["model"],
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": completion["choices"][0]["text"],
                },
                "finish_reason": completion["choices"][0]["finish_reason"],
            }
        ],
        "usage": completion["usage"],
    }


def _convert_text_completion_chunks_to_chat(
    chunks: Iterator[llama_types.CompletionChunk],
) -> Iterator[llama_types.ChatCompletionChunk]:
    for i, chunk in enumerate(chunks):
        if i == 0:
            yield {
                "id": "chat" + chunk["id"],
                "model": chunk["model"],
                "created": chunk["created"],
                "object": "chat.completion.chunk",
                "choices": [
                    {
                        "index": 0,
                        "delta": {
                            "role": "assistant",
                        },
                        "finish_reason": None,
                    }
                ],
            }
        yield {
            "id": "chat" + chunk["id"],
            "model": chunk["model"],
            "created": chunk["created"],
            "object": "chat.completion.chunk",
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": chunk["choices"][0]["text"],
                    }
                    if chunk["choices"][0]["finish_reason"] is None
                    else {},
                    "finish_reason": chunk["choices"][0]["finish_reason"],
                }
            ],
        }


def _convert_completion_to_chat(
    completion_or_chunks: Union[
        llama_types.Completion, Iterator[llama_types.CompletionChunk]
    ],
    stream: bool = False,
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    if stream:
        chunks: Iterator[llama_types.CompletionChunk] = completion_or_chunks  # type: ignore
        return _convert_text_completion_chunks_to_chat(chunks)
    else:
        completion: llama_types.Completion = completion_or_chunks  # type: ignore
        return _convert_text_completion_to_chat(completion)


_CHAT_FORMATS: Dict[str, ChatFormatter] = {}


def register_chat_format(name: str):
    def decorator(f: ChatFormatter):
        _CHAT_FORMATS[name] = f
        def basic_create_chat_completion(
            llama: llama.Llama,
            messages: List[llama_types.ChatCompletionRequestMessage],
            functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
            function_call: Optional[
                Union[str, llama_types.ChatCompletionFunctionCall]
            ] = None,
            temperature: float = 0.2,
            top_p: float = 0.95,
            top_k: int = 40,
            stream: bool = False,
            stop: Optional[Union[str, List[str]]] = [],
            max_tokens: int = 256,
            presence_penalty: float = 0.0,
            frequency_penalty: float = 0.0,
            repeat_penalty: float = 1.1,
            tfs_z: float = 1.0,
            mirostat_mode: int = 0,
            mirostat_tau: float = 5.0,
            mirostat_eta: float = 0.1,
            model: Optional[str] = None,
            logits_processor: Optional[llama.LogitsProcessorList] = None,
            grammar: Optional[llama.LlamaGrammar] = None,
        ) -> Union[
            llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
        ]:
            result = f(
                messages=messages,
                functions=functions,
                function_call=function_call,
            )
            prompt = result.prompt
            if result.stop is not None:
                stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
                rstop = result.stop if isinstance(result.stop, list) else [result.stop]
                stop = stop + rstop

            completion_or_chunks = llama.create_completion(
                prompt=prompt,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stream=stream,
                stop=stop,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                repeat_penalty=repeat_penalty,
                tfs_z=tfs_z,
                mirostat_mode=mirostat_mode,
                mirostat_tau=mirostat_tau,
                mirostat_eta=mirostat_eta,
                model=model,
                logits_processor=logits_processor,
                grammar=grammar,
            )
            return _convert_completion_to_chat(completion_or_chunks, stream=stream)  # type: ignore

        register_chat_completion_handler(name)(basic_create_chat_completion)
        return f

    return decorator
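Because register_chat_format now wraps every registered formatter into a chat completion handler, a plain prompt formatter is enough to get create_chat_completion support. A hedged sketch; the format name "simple" and the prompt layout are assumptions, not part of this commit:

import llama_cpp.llama_chat_format as llama_chat_format

@llama_chat_format.register_chat_format("simple")
def format_simple(messages, **kwargs):
    # Turn the message list into a flat prompt and stop when a new user turn starts.
    prompt = "".join(
        f"{m['role']}: {m.get('content') or ''}\n" for m in messages
    ) + "assistant: "
    return llama_chat_format.ChatFormatterResponse(prompt=prompt, stop=["user:"])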
@@ -320,3 +505,206 @@ def format_chatml(
    _messages.append((_roles["assistant"], None))
    _prompt = _format_chatml(system_message, _messages, _sep)
    return ChatFormatterResponse(prompt=_prompt)


@register_chat_completion_handler("functionary")
def functionary_chat_handler(
    llama: llama.Llama,
    messages: List[llama_types.ChatCompletionRequestMessage],
    functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
    function_call: Optional[Union[str, llama_types.ChatCompletionFunctionCall]] = None,
    temperature: float = 0.2,
    top_p: float = 0.95,
    top_k: int = 40,
    stream: bool = False,
    stop: Optional[Union[str, List[str]]] = [],
    max_tokens: int = 256,
    presence_penalty: float = 0.0,
    frequency_penalty: float = 0.0,
    repeat_penalty: float = 1.1,
    tfs_z: float = 1.0,
    mirostat_mode: int = 0,
    mirostat_tau: float = 5.0,
    mirostat_eta: float = 0.1,
    model: Optional[str] = None,
    logits_processor: Optional[llama.LogitsProcessorList] = None,
    grammar: Optional[llama.LlamaGrammar] = None,
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
    SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""

    def generate_schema_from_functions(
        functions: List[llama_types.ChatCompletionFunctions],
        namespace: str = "functions",
    ):
        """
        Convert functions schema to a schema that language models can understand.
        """

        schema = (
            "// Supported function definitions that should be called when necessary.\n"
        )
        schema += f"namespace {namespace} {{\n\n"

        for function in functions:
            # Convert a Function object to dict, if necessary
            function_name = function["name"]
            description = function.get("description", "")
            schema += f"// {description}\n"
            schema += f"type {function_name}"

            parameters = function.get("parameters", None)
            schema += " = (_: {\n"
            required_params = parameters.get("required", [])
            for param_name, param in parameters.get("properties", {}).items():
                # Param Description
                description = param.get("description")
                if description is not None:
                    schema += f"// {description}\n"

                # Param Name
                schema += f"{param_name}"
                if param_name not in required_params:
                    schema += "?"

                # Param Type
                param_type = param.get("type", "any")
                if param_type == "integer":
                    param_type = "number"
                if "enum" in param:
                    param_type = " | ".join([f'"{v}"' for v in param["enum"]])
                schema += f": {param_type},\n"

            schema += "}) => any;\n\n"

        schema += f"}} // namespace {namespace}"

        return schema

    def prepare_messages_for_inference(
        messages: List[llama_types.ChatCompletionRequestMessage],
        functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
    ):
        all_messages: List[llama_types.ChatCompletionRequestMessage] = []
        if functions is not None:
            all_messages.append(
                llama_types.ChatCompletionRequestMessage(
                    role="system", content=generate_schema_from_functions(functions)
                )
            )

        all_messages.append(
            llama_types.ChatCompletionRequestMessage(
                role="system", content=SYSTEM_MESSAGE
            )
        )

        for message in messages:
            # Function call responses
            if message["role"] == "function" and "name" in message:
                message["name"] = f"functions.{message['name']}"
            # Function call requests by assistant
            if "function_call" in message:
                message["function_call"][
                    "name"
                ] = f"functions.{message['function_call']['name']}"
            all_messages.append(message)

        all_messages.append(
            llama_types.ChatCompletionRequestMessage(role="assistant", content=None)
        )

        def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
            if msg["role"] == "system":
                return f"system:\n{msg['content']}\n"

            elif msg["role"] == "function" and "name" in msg:
                return f"function name={msg['name']}:\n{msg['content']}\n"
            elif msg["role"] == "function" and "function_call" in msg:
                return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
            elif msg["role"] == "user":
                if msg["content"] is None:
                    return "user:\n</s>"
                else:
                    return f"user:\n</s>{msg['content']}\n"
            elif msg["role"] == "assistant":
                if msg["content"] is not None and "function_call" in msg:
                    return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
                elif "function_call" in msg:
                    return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
                elif msg["content"] is None:
                    return "assistant"
                else:
                    return f"assistant:\n{msg['content']}\n"
            else:
                raise ValueError(f"Unsupported role: {msg['role']}")

        return "".join([message_to_str(msg) for msg in all_messages])

    prompt = prepare_messages_for_inference(messages, functions)

    if function_call is None and (functions is None or len(functions) == 0):
        completion_or_completion_chunks = llama.create_completion(
            prompt=prompt + ":\n",
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            stream=stream,
            stop=["user:", "</s>"],
            max_tokens=max_tokens,
            presence_penalty=presence_penalty,
            frequency_penalty=frequency_penalty,
            repeat_penalty=repeat_penalty,
            tfs_z=tfs_z,
            mirostat_mode=mirostat_mode,
            mirostat_tau=mirostat_tau,
            mirostat_eta=mirostat_eta,
            model=model,
            logits_processor=logits_processor,
            grammar=grammar,
        )
        return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream)  # type: ignore

    if function_call is None or (
        isinstance(function_call, str) and function_call == "auto"
    ):
        stop = "\n"
        completion: llama_types.Completion = llama.create_completion(
            prompt=prompt, stop=stop, stream=False
        )  # type: ignore
        completion_text = completion["choices"][0]["text"]
        # strip " to=functions." and ending ":"
        function_call = completion_text[14:-1]
        new_prompt = prompt + completion_text + stop
    elif isinstance(function_call, str) and function_call != "none":
        new_prompt = prompt + f"assistant:\n"
    elif isinstance(function_call, dict):
        new_prompt = prompt + f"assistant to={function_call['name']}:\n"
        function_call = function_call["name"]
    else:
        new_prompt = prompt + f"assistant:\n"

    completion: llama_types.Completion = llama.create_completion(
        prompt=new_prompt, stop=["user:", "</s>"], stream=False
    )  # type: ignore

    return llama_types.CreateChatCompletionResponse(
        id="chat" + completion["id"],
        object="chat.completion",
        created=completion["created"],
        model=completion["model"],
        choices=[
            {
                "index": 0,
                "message": {
                    "role": "function",
                    "content": None,
                    "function_call": {
                        "name": function_call,
                        "arguments": completion["choices"][0]["text"],
                    },
                },
                "finish_reason": "function_call",
            }
        ],
        usage=completion["usage"],
    )
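For illustration, tracing generate_schema_from_functions above over the notebook's get_current_weather definition yields a system message of roughly this shape; this rendering is derived from the code, not taken from the diff:

// Supported function definitions that should be called when necessary.
namespace functions {

// Get the current weather in a given location
type get_current_weather = (_: {
// The city and state, e.g. San Francisco, CA
location: string,
unit?: "celsius" | "fahrenheit",
}) => any;

} // namespace functions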
llama_cpp/llama_grammar.py
@@ -1,4 +1,5 @@
"""C++ implementation of the llama grammar parser."""
"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""

# flake8: noqa
from pathlib import Path
import sys

@@ -1056,8 +1057,7 @@ def print_rule(
    # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
    if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
        raise RuntimeError(
            "malformed rule, does not end with LLAMA_GRETYPE_END: "
            + str(rule_id)
            "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
        )
    print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
    # for (size_t i = 0, end = rule.size() - 1; i < end; i++) {

@@ -1102,9 +1102,7 @@ def print_rule(
    for i, elem in enumerate(rule[:-1]):
        case = elem.type  # type: llama_gretype
        if case is llama_gretype.LLAMA_GRETYPE_END:
            raise RuntimeError(
                "unexpected end of rule: " + str(rule_id) + "," + str(i)
            )
            raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
        elif case is llama_gretype.LLAMA_GRETYPE_ALT:
            print("| ", file=file, end="")
        elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:

@@ -1186,3 +1184,308 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
            f"{print_grammar.__name__}: error printing grammar: {err}",
            file=sys.stderr,
        )


"""llama.cpp gbnf rules from vendor/llama.cpp/grammars"""

ARITHMETIC_GBNF = """\
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
"""

C_GBNF = """\
root ::= (declaration)*

declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"

dataType ::= "int" ws | "float" ws | "char" ws
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*

parameter ::= dataType identifier

statement ::=
    ( dataType identifier ws "=" ws expression ";" ) |
    ( identifier ws "=" ws expression ";" ) |
    ( identifier ws "(" argList? ")" ";" ) |
    ( "return" ws expression ";" ) |
    ( "while" "(" condition ")" "{" statement* "}" ) |
    ( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
    ( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
    ( singleLineComment ) |
    ( multiLineComment )

forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
forUpdate ::= identifier ws "=" ws expression

condition ::= expression relationOperator expression
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")

expression ::= term (("+" | "-") term)*
term ::= factor(("*" | "/") factor)*

factor ::= identifier | number | unaryTerm | funcCall | parenExpression
unaryTerm ::= "-" factor
funcCall ::= identifier "(" argList? ")"
parenExpression ::= "(" ws expression ws ")"

argList ::= expression ("," ws expression)*

number ::= [0-9]+

singleLineComment ::= "//" [^\n]* "\n"
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"

ws ::= ([ \t\n]+)
"""

CHESS_GBNF = """\
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""

JAPANESE_GBNF = """\
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""

JSON_ARR_GBNF = """\
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays

root ::= arr
value ::= object | array | string | number | ("true" | "false" | "null") ws

arr ::=
  "[\n" ws (
    value
    (",\n" ws value)*
  )? "]"

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""


JSON_GBNF = """\
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws

object ::=
  "{" ws (
    string ":" ws value
    ("," ws string ":" ws value)*
  )? "}" ws

array ::=
  "[" ws (
    value
    ("," ws value)*
  )? "]" ws

string ::=
  "\"" (
    [^"\\] |
    "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
  )* "\"" ws

number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws

# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?"""

LIST_GBNF = """\
root ::= item+

# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
"""

"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
import json
import re
from typing import List, Optional

# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'

PRIMITIVE_RULES = {
    "boolean": '("true" | "false") space',
    "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
    "integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
    "string": r""" "\"" (
        [^"\\] |
        "\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
      )* "\"" space """,
    "null": '"null" space',
}

INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}


class SchemaConverter:
    def __init__(self, prop_order):
        self._prop_order = prop_order
        self._rules = {"space": SPACE_RULE}

    def _format_literal(self, literal):
        escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
            lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
        )
        return f'"{escaped}"'

    def _add_rule(self, name, rule):
        esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
        if esc_name not in self._rules or self._rules[esc_name] == rule:
            key = esc_name
        else:
            i = 0
            while f"{esc_name}{i}" in self._rules:
                i += 1
            key = f"{esc_name}{i}"
        self._rules[key] = rule
        return key

    def visit(self, schema, name):
        schema_type = schema.get("type")
        rule_name = name or "root"

        if "oneOf" in schema or "anyOf" in schema:
            rule = " | ".join(
                (
                    self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
                    for i, alt_schema in enumerate(
                        schema.get("oneOf") or schema["anyOf"]
                    )
                )
            )
            return self._add_rule(rule_name, rule)

        elif "const" in schema:
            return self._add_rule(rule_name, self._format_literal(schema["const"]))

        elif "enum" in schema:
            rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
            return self._add_rule(rule_name, rule)

        elif schema_type == "object" and "properties" in schema:
            # TODO: `required` keyword
            prop_order = self._prop_order
            prop_pairs = sorted(
                schema["properties"].items(),
                # sort by position in prop_order (if specified) then by key
                key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
            )

            rule = '"{" space'
            for i, (prop_name, prop_schema) in enumerate(prop_pairs):
                prop_rule_name = self.visit(
                    prop_schema, f'{name}{"-" if name else ""}{prop_name}'
                )
                if i > 0:
                    rule += ' "," space'
                rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
            rule += ' "}" space'

            return self._add_rule(rule_name, rule)

        elif schema_type == "array" and "items" in schema:
            # TODO `prefixItems` keyword
            item_rule_name = self.visit(
                schema["items"], f'{name}{"-" if name else ""}item'
            )
            rule = (
                f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space'
            )
            return self._add_rule(rule_name, rule)

        else:
            assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
            return self._add_rule(
                "root" if rule_name == "root" else schema_type,
                PRIMITIVE_RULES[schema_type],
            )

    def format_grammar(self):
        return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items()))


def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
    prop_order = prop_order or []
    schema = json.loads(schema)
    prop_order = {name: idx for idx, name in enumerate(prop_order)}
    converter = SchemaConverter(prop_order)
    converter.visit(schema, "")
    return converter.format_grammar()
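A hedged usage sketch for the new converter, assuming the json.loads reading of the schema argument above; the commented-out call assumes an existing Llama instance named llama:

import json

from llama_cpp.llama_grammar import LlamaGrammar, json_schema_to_gbnf

# Build a grammar that forces generated text to match a small JSON schema.
schema = json.dumps(
    {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
    }
)
grammar = LlamaGrammar.from_string(json_schema_to_gbnf(schema))
# output = llama.create_completion("Extract: Jason is 25 years old. JSON:", grammar=grammar)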
llama_cpp/server/app.py
@@ -743,21 +743,22 @@ async def create_embedding(


class ChatCompletionRequestMessage(BaseModel):
    role: Literal["system", "user", "assistant"] = Field(
    role: Literal["system", "user", "assistant", "function"] = Field(
        default="user", description="The role of the message."
    )
    content: str = Field(default="", description="The content of the message.")
    content: Optional[str] = Field(default="", description="The content of the message.")

from typing import Any

class CreateChatCompletionRequest(BaseModel):
    messages: List[ChatCompletionRequestMessage] = Field(
    messages: List[Any] = Field(
        default=[], description="A list of messages to generate completions for."
    )
    functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
        default=None,
        description="A list of functions to apply to the generated completions.",
    )
    function_call: Optional[Union[str, llama_cpp.ChatCompletionFunctionCall]] = Field(
    function_call: Optional[Union[Literal["auto", "none"], llama_cpp.ChatCompletionFunctionCallOption]] = Field(
        default=None,
        description="A function to apply to the generated completions.",
    )