Add functionary support (#784)

* Add common grammars and json-schema-to-grammar utility function from llama.cpp

* Pass functions to format function

* Add basic functionary formatting

* Add LlamaChatHandler for more complex chat use cases

* Add function calling example notebook

* Add support for regular chat completions alongside function calling
This commit is contained in:
Andrei 2023-11-03 02:12:14 -04:00 committed by GitHub
parent df31303a12
commit 3af7b21ff1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 936 additions and 99 deletions

View file

@ -0,0 +1,225 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"id\": \"chatcmpl-a6db1bbb-a128-4c28-88fe-30717ec806b2\",\n",
" \"object\": \"chat.completion\",\n",
" \"created\": 1698989577,\n",
" \"model\": \"gpt-3.5-turbo-0613\",\n",
" \"choices\": [\n",
" {\n",
" \"index\": 0,\n",
" \"message\": {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"The current weather in Boston is sunny with a temperature of 72 degrees\"\n",
" },\n",
" \"finish_reason\": \"length\"\n",
" }\n",
" ],\n",
" \"usage\": {\n",
" \"prompt_tokens\": 135,\n",
" \"completion_tokens\": 16,\n",
" \"total_tokens\": 151\n",
" }\n",
"}\n"
]
}
],
"source": [
"import openai\n",
"import json\n",
"\n",
"openai.api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n",
"openai.api_base = \"http://100.64.159.73:8000/v1\"\n",
"\n",
"# Example dummy function hard coded to return the same weather\n",
"# In production, this could be your backend API or an external API\n",
"def get_current_weather(location, unit=\"fahrenheit\"):\n",
" \"\"\"Get the current weather in a given location\"\"\"\n",
" weather_info = {\n",
" \"location\": location,\n",
" \"temperature\": \"72\",\n",
" \"unit\": unit,\n",
" \"forecast\": [\"sunny\", \"windy\"],\n",
" }\n",
" return json.dumps(weather_info)\n",
"\n",
"def run_conversation():\n",
" # Step 1: send the conversation and available functions to GPT\n",
" messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}]\n",
" functions = [\n",
" {\n",
" \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
" },\n",
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
" },\n",
" \"required\": [\"location\"],\n",
" },\n",
" }\n",
" ]\n",
" response = openai.ChatCompletion.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n",
" messages=messages,\n",
" functions=functions,\n",
" function_call=\"auto\", # auto is default, but we'll be explicit\n",
" )\n",
" response_message = response[\"choices\"][0][\"message\"]\n",
"\n",
" # Step 2: check if GPT wanted to call a function\n",
" if response_message.get(\"function_call\"):\n",
" # Step 3: call the function\n",
" # Note: the JSON response may not always be valid; be sure to handle errors\n",
" available_functions = {\n",
" \"get_current_weather\": get_current_weather,\n",
" } # only one function in this example, but you can have multiple\n",
" function_name = response_message[\"function_call\"][\"name\"]\n",
" fuction_to_call = available_functions[function_name]\n",
" function_args = json.loads(response_message[\"function_call\"][\"arguments\"])\n",
" function_response = fuction_to_call(\n",
" location=function_args.get(\"location\"),\n",
" unit=function_args.get(\"unit\"),\n",
" )\n",
"\n",
" # Step 4: send the info on the function call and function response to GPT\n",
" messages.append(response_message) # extend conversation with assistant's reply\n",
" messages.append(\n",
" {\n",
" \"role\": \"function\",\n",
" \"name\": function_name,\n",
" \"content\": function_response,\n",
" }\n",
" ) # extend conversation with function response\n",
" second_response = openai.ChatCompletion.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n",
" messages=messages,\n",
" ) # get a new response from GPT where it can see the function response\n",
" return second_response\n",
" else:\n",
" print(response)\n",
" print(\"No function\")\n",
"\n",
"print(run_conversation())"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"name='Jason' age=25\n"
]
}
],
"source": [
"from pydantic import BaseModel\n",
"from instructor import patch\n",
"\n",
"patch()\n",
"\n",
"class UserDetail(BaseModel):\n",
" name: str\n",
" age: int\n",
"\n",
"user: UserDetail = openai.ChatCompletion.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" response_model=UserDetail,\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n",
" ]\n",
")\n",
"print(user)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"id\": \"chatcmpl-59bcefad-9df5-4d6b-802c-5537b3e9044e\",\n",
" \"object\": \"chat.completion\",\n",
" \"created\": 1698989585,\n",
" \"model\": \"gpt-3.5-turbo-0613\",\n",
" \"choices\": [\n",
" {\n",
" \"index\": 0,\n",
" \"message\": {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"I don't have up-to-date information on the current weather conditions\"\n",
" },\n",
" \"finish_reason\": \"length\"\n",
" }\n",
" ],\n",
" \"usage\": {\n",
" \"prompt_tokens\": 62,\n",
" \"completion_tokens\": 16,\n",
" \"total_tokens\": 78\n",
" }\n",
"}\n"
]
}
],
"source": [
"response = openai.ChatCompletion.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n",
" messages=[\n",
" {\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}\n",
" ]\n",
")\n",
"print(response)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "python-3.8.10",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5+"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View file

@ -24,7 +24,7 @@ import ctypes
from . import llama_cpp from . import llama_cpp
from .llama_types import * from .llama_types import *
from .llama_grammar import LlamaGrammar from .llama_grammar import LlamaGrammar
from . import llama_chat_format import llama_cpp.llama_chat_format as llama_chat_format
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -428,7 +428,7 @@ class Llama:
if self.verbose: if self.verbose:
print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr) print(llama_cpp.llama_print_system_info().decode("utf-8"), file=sys.stderr)
self.chat_format = chat_format self.chat_format = chat_format
self._n_vocab = self.n_vocab() self._n_vocab = self.n_vocab()
@ -1539,78 +1539,6 @@ class Llama:
grammar=grammar, grammar=grammar,
) )
def _convert_text_completion_to_chat(
self, completion: Completion
) -> ChatCompletion:
return {
"id": "chat" + completion["id"],
"object": "chat.completion",
"created": completion["created"],
"model": completion["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": completion["choices"][0]["text"],
},
"finish_reason": completion["choices"][0]["finish_reason"],
}
],
"usage": completion["usage"],
}
def _convert_text_completion_chunks_to_chat(
self,
chunks: Iterator[CompletionChunk],
) -> Iterator[ChatCompletionChunk]:
for i, chunk in enumerate(chunks):
if i == 0:
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
},
"finish_reason": None,
}
],
}
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"content": chunk["choices"][0]["text"],
}
if chunk["choices"][0]["finish_reason"] is None
else {},
"finish_reason": chunk["choices"][0]["finish_reason"],
}
],
}
def _convert_completion_to_chat(
self,
completion_or_chunks: Union[Completion, Iterator[CompletionChunk]],
stream: bool = False,
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
if stream:
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
return self._convert_text_completion_chunks_to_chat(chunks)
else:
completion: Completion = completion_or_chunks # type: ignore
return self._convert_text_completion_to_chat(completion)
def create_chat_completion( def create_chat_completion(
self, self,
messages: List[ChatCompletionRequestMessage], messages: List[ChatCompletionRequestMessage],
@ -1648,19 +1576,12 @@ class Llama:
Returns: Returns:
Generated chat completion or a stream of chat completion chunks. Generated chat completion or a stream of chat completion chunks.
""" """
handler = llama_chat_format.get_chat_completion_handler(self.chat_format)
format = llama_chat_format.get_chat_format(self.chat_format) return handler(
result = format( self,
messages=messages, messages=messages,
) functions=functions,
prompt = result.prompt function_call=function_call,
if result.stop is not None:
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop
completion_or_chunks = self.create_completion(
prompt=prompt,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
@ -1678,7 +1599,6 @@ class Llama:
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
) )
return self._convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
def _free_model(self, *, _lbatch_free=llama_cpp._lib.llama_batch_free, _lfree_model=llama_cpp._lib.llama_free_model, _free=llama_cpp._lib.llama_free): def _free_model(self, *, _lbatch_free=llama_cpp._lib.llama_batch_free, _lfree_model=llama_cpp._lib.llama_free_model, _free=llama_cpp._lib.llama_free):
batch = getattr(self, 'batch', None) batch = getattr(self, 'batch', None)

View file

@ -1,6 +1,53 @@
from __future__ import annotations
import dataclasses import dataclasses
from typing import Any, Dict, List, Optional, Tuple, Union, Protocol from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
from . import llama_types from . import llama_types
from . import llama
class LlamaChatCompletionHandler(Protocol):
def __call__(
self,
llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[
Union[str, llama_types.ChatCompletionFunctionCall]
] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
...
CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
return CHAT_HANDLERS[name]
def register_chat_completion_handler(name: str):
def decorator(f: LlamaChatCompletionHandler):
CHAT_HANDLERS[name] = f
return f
return decorator
def _get_system_message( def _get_system_message(
@ -119,12 +166,150 @@ class ChatFormatter(Protocol):
... ...
class BasicChatHandler:
def __init__(self, chat_format: str):
self.chat_format = chat_format
def _convert_text_completion_to_chat(
completion: llama_types.Completion,
) -> llama_types.ChatCompletion:
return {
"id": "chat" + completion["id"],
"object": "chat.completion",
"created": completion["created"],
"model": completion["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": completion["choices"][0]["text"],
},
"finish_reason": completion["choices"][0]["finish_reason"],
}
],
"usage": completion["usage"],
}
def _convert_text_completion_chunks_to_chat(
chunks: Iterator[llama_types.CompletionChunk],
) -> Iterator[llama_types.ChatCompletionChunk]:
for i, chunk in enumerate(chunks):
if i == 0:
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"role": "assistant",
},
"finish_reason": None,
}
],
}
yield {
"id": "chat" + chunk["id"],
"model": chunk["model"],
"created": chunk["created"],
"object": "chat.completion.chunk",
"choices": [
{
"index": 0,
"delta": {
"content": chunk["choices"][0]["text"],
}
if chunk["choices"][0]["finish_reason"] is None
else {},
"finish_reason": chunk["choices"][0]["finish_reason"],
}
],
}
def _convert_completion_to_chat(
completion_or_chunks: Union[
llama_types.Completion, Iterator[llama_types.CompletionChunk]
],
stream: bool = False,
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
if stream:
chunks: Iterator[llama_types.CompletionChunk] = completion_or_chunks # type: ignore
return _convert_text_completion_chunks_to_chat(chunks)
else:
completion: llama_types.Completion = completion_or_chunks # type: ignore
return _convert_text_completion_to_chat(completion)
_CHAT_FORMATS: Dict[str, ChatFormatter] = {} _CHAT_FORMATS: Dict[str, ChatFormatter] = {}
def register_chat_format(name: str): def register_chat_format(name: str):
def decorator(f: ChatFormatter): def decorator(f: ChatFormatter):
_CHAT_FORMATS[name] = f def basic_create_chat_completion(
llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[
Union[str, llama_types.ChatCompletionFunctionCall]
] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
) -> Union[
llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
]:
result = f(
messages=messages,
functions=functions,
function_call=function_call,
)
prompt = result.prompt
if result.stop is not None:
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop
completion_or_chunks = llama.create_completion(
prompt=prompt,
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream=stream,
stop=stop,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
)
return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
register_chat_completion_handler(name)(basic_create_chat_completion)
return f return f
return decorator return decorator
@ -320,3 +505,206 @@ def format_chatml(
_messages.append((_roles["assistant"], None)) _messages.append((_roles["assistant"], None))
_prompt = _format_chatml(system_message, _messages, _sep) _prompt = _format_chatml(system_message, _messages, _sep)
return ChatFormatterResponse(prompt=_prompt) return ChatFormatterResponse(prompt=_prompt)
@register_chat_completion_handler("functionary")
def functionary_chat_handler(
llama: llama.Llama,
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[Union[str, llama_types.ChatCompletionFunctionCall]] = None,
temperature: float = 0.2,
top_p: float = 0.95,
top_k: int = 40,
stream: bool = False,
stop: Optional[Union[str, List[str]]] = [],
max_tokens: int = 256,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1,
tfs_z: float = 1.0,
mirostat_mode: int = 0,
mirostat_tau: float = 5.0,
mirostat_eta: float = 0.1,
model: Optional[str] = None,
logits_processor: Optional[llama.LogitsProcessorList] = None,
grammar: Optional[llama.LlamaGrammar] = None,
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
def generate_schema_from_functions(
functions: List[llama_types.ChatCompletionFunctions],
namespace: str = "functions",
):
"""
Convert functions schema to a schema that language models can understand.
"""
schema = (
"// Supported function definitions that should be called when necessary.\n"
)
schema += f"namespace {namespace} {{\n\n"
for function in functions:
# Convert a Function object to dict, if necessary
function_name = function["name"]
description = function.get("description", "")
schema += f"// {description}\n"
schema += f"type {function_name}"
parameters = function.get("parameters", None)
schema += " = (_: {\n"
required_params = parameters.get("required", [])
for param_name, param in parameters.get("properties", {}).items():
# Param Description
description = param.get("description")
if description is not None:
schema += f"// {description}\n"
# Param Name
schema += f"{param_name}"
if param_name not in required_params:
schema += "?"
# Param Type
param_type = param.get("type", "any")
if param_type == "integer":
param_type = "number"
if "enum" in param:
param_type = " | ".join([f'"{v}"' for v in param["enum"]])
schema += f": {param_type},\n"
schema += "}) => any;\n\n"
schema += f"}} // namespace {namespace}"
return schema
def prepare_messages_for_inference(
messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
):
all_messages: List[llama_types.ChatCompletionRequestMessage] = []
if functions is not None:
all_messages.append(
llama_types.ChatCompletionRequestMessage(
role="system", content=generate_schema_from_functions(functions)
)
)
all_messages.append(
llama_types.ChatCompletionRequestMessage(
role="system", content=SYSTEM_MESSAGE
)
)
for message in messages:
# Function call responses
if message["role"] == "function" and "name" in message:
message["name"] = f"functions.{message['name']}"
# Function call requests by assistant
if "function_call" in message:
message["function_call"][
"name"
] = f"functions.{message['function_call']['name']}"
all_messages.append(message)
all_messages.append(
llama_types.ChatCompletionRequestMessage(role="assistant", content=None)
)
def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
if msg["role"] == "system":
return f"system:\n{msg['content']}\n"
elif msg["role"] == "function" and "name" in msg:
return f"function name={msg['name']}:\n{msg['content']}\n"
elif msg["role"] == "function" and "function_call" in msg:
return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
elif msg["role"] == "user":
if msg["content"] is None:
return "user:\n</s>"
else:
return f"user:\n</s>{msg['content']}\n"
elif msg["role"] == "assistant":
if msg["content"] is not None and "function_call" in msg:
return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
elif "function_call" in msg:
return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
elif msg["content"] is None:
return "assistant"
else:
return f"assistant:\n{msg['content']}\n"
else:
raise ValueError(f"Unsupported role: {msg['role']}")
return "".join([message_to_str(msg) for msg in all_messages])
prompt = prepare_messages_for_inference(messages, functions)
if function_call is None and (functions is None or len(functions) == 0):
completion_or_completion_chunks = llama.create_completion(
prompt=prompt + ":\n",
temperature=temperature,
top_p=top_p,
top_k=top_k,
stream=stream,
stop=["user:", "</s>"],
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
grammar=grammar,
)
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
if function_call is None or (
isinstance(function_call, str) and function_call == "auto"
):
stop = "\n"
completion: llama_types.Completion = llama.create_completion(
prompt=prompt, stop=stop, stream=False
) # type: ignore
completion_text = completion["choices"][0]["text"]
# strip " to=functions." and ending ":"
function_call = completion_text[14:-1]
new_prompt = prompt + completion_text + stop
elif isinstance(function_call, str) and function_call != "none":
new_prompt = prompt + f"assistant:\n"
elif isinstance(function_call, dict):
new_prompt = prompt + f"assistant to={function_call['name']}:\n"
function_call = function_call["name"]
else:
new_prompt = prompt + f"assistant:\n"
completion: llama_types.Completion = llama.create_completion(
prompt=new_prompt, stop=["user:", "</s>"], stream=False
) # type: ignore
return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"],
object="chat.completion",
created=completion["created"],
model=completion["model"],
choices=[
{
"index": 0,
"message": {
"role": "function",
"content": None,
"function_call": {
"name": function_call,
"arguments": completion["choices"][0]["text"],
},
},
"finish_reason": "function_call",
}
],
usage=completion["usage"],
)

View file

@ -1,4 +1,5 @@
"""C++ implementation of the llama grammar parser.""" """Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
# flake8: noqa # flake8: noqa
from pathlib import Path from pathlib import Path
import sys import sys
@ -1056,8 +1057,7 @@ def print_rule(
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str()); # fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END: if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
raise RuntimeError( raise RuntimeError(
"malformed rule, does not end with LLAMA_GRETYPE_END: " "malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
+ str(rule_id)
) )
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ") print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) { # for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
@ -1102,9 +1102,7 @@ def print_rule(
for i, elem in enumerate(rule[:-1]): for i, elem in enumerate(rule[:-1]):
case = elem.type # type: llama_gretype case = elem.type # type: llama_gretype
if case is llama_gretype.LLAMA_GRETYPE_END: if case is llama_gretype.LLAMA_GRETYPE_END:
raise RuntimeError( raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
"unexpected end of rule: " + str(rule_id) + "," + str(i)
)
elif case is llama_gretype.LLAMA_GRETYPE_ALT: elif case is llama_gretype.LLAMA_GRETYPE_ALT:
print("| ", file=file, end="") print("| ", file=file, end="")
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF: elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
@ -1186,3 +1184,308 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
f"{print_grammar.__name__}: error printing grammar: {err}", f"{print_grammar.__name__}: error printing grammar: {err}",
file=sys.stderr, file=sys.stderr,
) )
"""llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
ARITHMETIC_GBNF = """\
root ::= (expr "=" ws term "\n")+
expr ::= term ([-+*/] term)*
term ::= ident | num | "(" ws expr ")" ws
ident ::= [a-z] [a-z0-9_]* ws
num ::= [0-9]+ ws
ws ::= [ \t\n]*
"""
C_GBNF = """\
root ::= (declaration)*
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
dataType ::= "int" ws | "float" ws | "char" ws
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
parameter ::= dataType identifier
statement ::=
( dataType identifier ws "=" ws expression ";" ) |
( identifier ws "=" ws expression ";" ) |
( identifier ws "(" argList? ")" ";" ) |
( "return" ws expression ";" ) |
( "while" "(" condition ")" "{" statement* "}" ) |
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
( singleLineComment ) |
( multiLineComment )
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
forUpdate ::= identifier ws "=" ws expression
condition ::= expression relationOperator expression
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
expression ::= term (("+" | "-") term)*
term ::= factor(("*" | "/") factor)*
factor ::= identifier | number | unaryTerm | funcCall | parenExpression
unaryTerm ::= "-" factor
funcCall ::= identifier "(" argList? ")"
parenExpression ::= "(" ws expression ws ")"
argList ::= expression ("," ws expression)*
number ::= [0-9]+
singleLineComment ::= "//" [^\n]* "\n"
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
ws ::= ([ \t\n]+)
"""
CHESS_GBNF = """\
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
JAPANESE_GBNF = """\
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
JSON_ARR_GBNF = """\
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
# Useful for generating JSON arrays
root ::= arr
value ::= object | array | string | number | ("true" | "false" | "null") ws
arr ::=
"[\n" ws (
value
(",\n" ws value)*
)? "]"
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?
"""
JSON_GBNF = """\
root ::= object
value ::= object | array | string | number | ("true" | "false" | "null") ws
object ::=
"{" ws (
string ":" ws value
("," ws string ":" ws value)*
)? "}" ws
array ::=
"[" ws (
value
("," ws value)*
)? "]" ws
string ::=
"\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
)* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
# Optional space: by convention, applied in this grammar after literal chars when allowed
ws ::= ([ \t\n] ws)?"""
LIST_GBNF = """\
root ::= item+
# Excludes various line break characters
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
"""
"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
import json
import re
from typing import List, Optional
# whitespace is constrained to a single space char to prevent model "running away" in
# whitespace. Also maybe improves generation quality?
SPACE_RULE = '" "?'
PRIMITIVE_RULES = {
"boolean": '("true" | "false") space',
"number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
"integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
"string": r""" "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space """,
"null": '"null" space',
}
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
class SchemaConverter:
def __init__(self, prop_order):
self._prop_order = prop_order
self._rules = {"space": SPACE_RULE}
def _format_literal(self, literal):
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
)
return f'"{escaped}"'
def _add_rule(self, name, rule):
esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
if esc_name not in self._rules or self._rules[esc_name] == rule:
key = esc_name
else:
i = 0
while f"{esc_name}{i}" in self._rules:
i += 1
key = f"{esc_name}{i}"
self._rules[key] = rule
return key
def visit(self, schema, name):
schema_type = schema.get("type")
rule_name = name or "root"
if "oneOf" in schema or "anyOf" in schema:
rule = " | ".join(
(
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
for i, alt_schema in enumerate(
schema.get("oneOf") or schema["anyOf"]
)
)
)
return self._add_rule(rule_name, rule)
elif "const" in schema:
return self._add_rule(rule_name, self._format_literal(schema["const"]))
elif "enum" in schema:
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
return self._add_rule(rule_name, rule)
elif schema_type == "object" and "properties" in schema:
# TODO: `required` keyword
prop_order = self._prop_order
prop_pairs = sorted(
schema["properties"].items(),
# sort by position in prop_order (if specified) then by key
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
)
rule = '"{" space'
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
prop_rule_name = self.visit(
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
)
if i > 0:
rule += ' "," space'
rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
rule += ' "}" space'
return self._add_rule(rule_name, rule)
elif schema_type == "array" and "items" in schema:
# TODO `prefixItems` keyword
item_rule_name = self.visit(
schema["items"], f'{name}{"-" if name else ""}item'
)
rule = (
f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space'
)
return self._add_rule(rule_name, rule)
else:
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
return self._add_rule(
"root" if rule_name == "root" else schema_type,
PRIMITIVE_RULES[schema_type],
)
def format_grammar(self):
return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items()))
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
prop_order = prop_order or []
schema = json.load(schema)
prop_order = {name: idx for idx, name in enumerate(prop_order)}
converter = SchemaConverter(prop_order)
converter.visit(schema, "")
return converter.format_grammar()

View file

@ -743,21 +743,22 @@ async def create_embedding(
class ChatCompletionRequestMessage(BaseModel): class ChatCompletionRequestMessage(BaseModel):
role: Literal["system", "user", "assistant"] = Field( role: Literal["system", "user", "assistant", "function"] = Field(
default="user", description="The role of the message." default="user", description="The role of the message."
) )
content: str = Field(default="", description="The content of the message.") content: Optional[str] = Field(default="", description="The content of the message.")
from typing import Any
class CreateChatCompletionRequest(BaseModel): class CreateChatCompletionRequest(BaseModel):
messages: List[ChatCompletionRequestMessage] = Field( messages: List[Any] = Field(
default=[], description="A list of messages to generate completions for." default=[], description="A list of messages to generate completions for."
) )
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field( functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
default=None, default=None,
description="A list of functions to apply to the generated completions.", description="A list of functions to apply to the generated completions.",
) )
function_call: Optional[Union[str, llama_cpp.ChatCompletionFunctionCall]] = Field( function_call: Optional[Union[Literal["auto", "none"], llama_cpp.ChatCompletionFunctionCallOption]] = Field(
default=None, default=None,
description="A function to apply to the generated completions.", description="A function to apply to the generated completions.",
) )