Add functionary support (#784)
* Add common grammars and json-schema-to-grammar utility function from llama.cpp * Pass functions to format function * Add basic functionary formatting * Add LlamaChatHandler for more complex chat use cases * Add function calling example notebook * Add support for regular chat completions alongside function calling
This commit is contained in:
parent
df31303a12
commit
3af7b21ff1
5 changed files with 936 additions and 99 deletions
225
examples/notebooks/Functions.ipynb
Normal file
225
examples/notebooks/Functions.ipynb
Normal file
|
@ -0,0 +1,225 @@
|
||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"{\n",
|
||||||
|
" \"id\": \"chatcmpl-a6db1bbb-a128-4c28-88fe-30717ec806b2\",\n",
|
||||||
|
" \"object\": \"chat.completion\",\n",
|
||||||
|
" \"created\": 1698989577,\n",
|
||||||
|
" \"model\": \"gpt-3.5-turbo-0613\",\n",
|
||||||
|
" \"choices\": [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"index\": 0,\n",
|
||||||
|
" \"message\": {\n",
|
||||||
|
" \"role\": \"assistant\",\n",
|
||||||
|
" \"content\": \"The current weather in Boston is sunny with a temperature of 72 degrees\"\n",
|
||||||
|
" },\n",
|
||||||
|
" \"finish_reason\": \"length\"\n",
|
||||||
|
" }\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"usage\": {\n",
|
||||||
|
" \"prompt_tokens\": 135,\n",
|
||||||
|
" \"completion_tokens\": 16,\n",
|
||||||
|
" \"total_tokens\": 151\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import openai\n",
|
||||||
|
"import json\n",
|
||||||
|
"\n",
|
||||||
|
"openai.api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n",
|
||||||
|
"openai.api_base = \"http://100.64.159.73:8000/v1\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Example dummy function hard coded to return the same weather\n",
|
||||||
|
"# In production, this could be your backend API or an external API\n",
|
||||||
|
"def get_current_weather(location, unit=\"fahrenheit\"):\n",
|
||||||
|
" \"\"\"Get the current weather in a given location\"\"\"\n",
|
||||||
|
" weather_info = {\n",
|
||||||
|
" \"location\": location,\n",
|
||||||
|
" \"temperature\": \"72\",\n",
|
||||||
|
" \"unit\": unit,\n",
|
||||||
|
" \"forecast\": [\"sunny\", \"windy\"],\n",
|
||||||
|
" }\n",
|
||||||
|
" return json.dumps(weather_info)\n",
|
||||||
|
"\n",
|
||||||
|
"def run_conversation():\n",
|
||||||
|
" # Step 1: send the conversation and available functions to GPT\n",
|
||||||
|
" messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}]\n",
|
||||||
|
" functions = [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"name\": \"get_current_weather\",\n",
|
||||||
|
" \"description\": \"Get the current weather in a given location\",\n",
|
||||||
|
" \"parameters\": {\n",
|
||||||
|
" \"type\": \"object\",\n",
|
||||||
|
" \"properties\": {\n",
|
||||||
|
" \"location\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
|
||||||
|
" },\n",
|
||||||
|
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
|
||||||
|
" },\n",
|
||||||
|
" \"required\": [\"location\"],\n",
|
||||||
|
" },\n",
|
||||||
|
" }\n",
|
||||||
|
" ]\n",
|
||||||
|
" response = openai.ChatCompletion.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo-0613\",\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" functions=functions,\n",
|
||||||
|
" function_call=\"auto\", # auto is default, but we'll be explicit\n",
|
||||||
|
" )\n",
|
||||||
|
" response_message = response[\"choices\"][0][\"message\"]\n",
|
||||||
|
"\n",
|
||||||
|
" # Step 2: check if GPT wanted to call a function\n",
|
||||||
|
" if response_message.get(\"function_call\"):\n",
|
||||||
|
" # Step 3: call the function\n",
|
||||||
|
" # Note: the JSON response may not always be valid; be sure to handle errors\n",
|
||||||
|
" available_functions = {\n",
|
||||||
|
" \"get_current_weather\": get_current_weather,\n",
|
||||||
|
" } # only one function in this example, but you can have multiple\n",
|
||||||
|
" function_name = response_message[\"function_call\"][\"name\"]\n",
|
||||||
|
" fuction_to_call = available_functions[function_name]\n",
|
||||||
|
" function_args = json.loads(response_message[\"function_call\"][\"arguments\"])\n",
|
||||||
|
" function_response = fuction_to_call(\n",
|
||||||
|
" location=function_args.get(\"location\"),\n",
|
||||||
|
" unit=function_args.get(\"unit\"),\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" # Step 4: send the info on the function call and function response to GPT\n",
|
||||||
|
" messages.append(response_message) # extend conversation with assistant's reply\n",
|
||||||
|
" messages.append(\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"function\",\n",
|
||||||
|
" \"name\": function_name,\n",
|
||||||
|
" \"content\": function_response,\n",
|
||||||
|
" }\n",
|
||||||
|
" ) # extend conversation with function response\n",
|
||||||
|
" second_response = openai.ChatCompletion.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo-0613\",\n",
|
||||||
|
" messages=messages,\n",
|
||||||
|
" ) # get a new response from GPT where it can see the function response\n",
|
||||||
|
" return second_response\n",
|
||||||
|
" else:\n",
|
||||||
|
" print(response)\n",
|
||||||
|
" print(\"No function\")\n",
|
||||||
|
"\n",
|
||||||
|
"print(run_conversation())"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 2,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"name='Jason' age=25\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from pydantic import BaseModel\n",
|
||||||
|
"from instructor import patch\n",
|
||||||
|
"\n",
|
||||||
|
"patch()\n",
|
||||||
|
"\n",
|
||||||
|
"class UserDetail(BaseModel):\n",
|
||||||
|
" name: str\n",
|
||||||
|
" age: int\n",
|
||||||
|
"\n",
|
||||||
|
"user: UserDetail = openai.ChatCompletion.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo\",\n",
|
||||||
|
" response_model=UserDetail,\n",
|
||||||
|
" messages=[\n",
|
||||||
|
" {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n",
|
||||||
|
" ]\n",
|
||||||
|
")\n",
|
||||||
|
"print(user)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"{\n",
|
||||||
|
" \"id\": \"chatcmpl-59bcefad-9df5-4d6b-802c-5537b3e9044e\",\n",
|
||||||
|
" \"object\": \"chat.completion\",\n",
|
||||||
|
" \"created\": 1698989585,\n",
|
||||||
|
" \"model\": \"gpt-3.5-turbo-0613\",\n",
|
||||||
|
" \"choices\": [\n",
|
||||||
|
" {\n",
|
||||||
|
" \"index\": 0,\n",
|
||||||
|
" \"message\": {\n",
|
||||||
|
" \"role\": \"assistant\",\n",
|
||||||
|
" \"content\": \"I don't have up-to-date information on the current weather conditions\"\n",
|
||||||
|
" },\n",
|
||||||
|
" \"finish_reason\": \"length\"\n",
|
||||||
|
" }\n",
|
||||||
|
" ],\n",
|
||||||
|
" \"usage\": {\n",
|
||||||
|
" \"prompt_tokens\": 62,\n",
|
||||||
|
" \"completion_tokens\": 16,\n",
|
||||||
|
" \"total_tokens\": 78\n",
|
||||||
|
" }\n",
|
||||||
|
"}\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"response = openai.ChatCompletion.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo-0613\",\n",
|
||||||
|
" messages=[\n",
|
||||||
|
" {\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}\n",
|
||||||
|
" ]\n",
|
||||||
|
")\n",
|
||||||
|
"print(response)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": []
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "python-3.8.10",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.5+"
|
||||||
|
},
|
||||||
|
"orig_nbformat": 4
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
|
@ -24,7 +24,7 @@ import ctypes
|
||||||
from . import llama_cpp
|
from . import llama_cpp
|
||||||
from .llama_types import *
|
from .llama_types import *
|
||||||
from .llama_grammar import LlamaGrammar
|
from .llama_grammar import LlamaGrammar
|
||||||
from . import llama_chat_format
|
import llama_cpp.llama_chat_format as llama_chat_format
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
|
@ -1539,78 +1539,6 @@ class Llama:
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _convert_text_completion_to_chat(
|
|
||||||
self, completion: Completion
|
|
||||||
) -> ChatCompletion:
|
|
||||||
return {
|
|
||||||
"id": "chat" + completion["id"],
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": completion["created"],
|
|
||||||
"model": completion["model"],
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": completion["choices"][0]["text"],
|
|
||||||
},
|
|
||||||
"finish_reason": completion["choices"][0]["finish_reason"],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": completion["usage"],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _convert_text_completion_chunks_to_chat(
|
|
||||||
self,
|
|
||||||
chunks: Iterator[CompletionChunk],
|
|
||||||
) -> Iterator[ChatCompletionChunk]:
|
|
||||||
for i, chunk in enumerate(chunks):
|
|
||||||
if i == 0:
|
|
||||||
yield {
|
|
||||||
"id": "chat" + chunk["id"],
|
|
||||||
"model": chunk["model"],
|
|
||||||
"created": chunk["created"],
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"role": "assistant",
|
|
||||||
},
|
|
||||||
"finish_reason": None,
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
yield {
|
|
||||||
"id": "chat" + chunk["id"],
|
|
||||||
"model": chunk["model"],
|
|
||||||
"created": chunk["created"],
|
|
||||||
"object": "chat.completion.chunk",
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {
|
|
||||||
"content": chunk["choices"][0]["text"],
|
|
||||||
}
|
|
||||||
if chunk["choices"][0]["finish_reason"] is None
|
|
||||||
else {},
|
|
||||||
"finish_reason": chunk["choices"][0]["finish_reason"],
|
|
||||||
}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
def _convert_completion_to_chat(
|
|
||||||
self,
|
|
||||||
completion_or_chunks: Union[Completion, Iterator[CompletionChunk]],
|
|
||||||
stream: bool = False,
|
|
||||||
) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
|
|
||||||
if stream:
|
|
||||||
chunks: Iterator[CompletionChunk] = completion_or_chunks # type: ignore
|
|
||||||
return self._convert_text_completion_chunks_to_chat(chunks)
|
|
||||||
else:
|
|
||||||
completion: Completion = completion_or_chunks # type: ignore
|
|
||||||
return self._convert_text_completion_to_chat(completion)
|
|
||||||
|
|
||||||
def create_chat_completion(
|
def create_chat_completion(
|
||||||
self,
|
self,
|
||||||
messages: List[ChatCompletionRequestMessage],
|
messages: List[ChatCompletionRequestMessage],
|
||||||
|
@ -1648,19 +1576,12 @@ class Llama:
|
||||||
Returns:
|
Returns:
|
||||||
Generated chat completion or a stream of chat completion chunks.
|
Generated chat completion or a stream of chat completion chunks.
|
||||||
"""
|
"""
|
||||||
|
handler = llama_chat_format.get_chat_completion_handler(self.chat_format)
|
||||||
format = llama_chat_format.get_chat_format(self.chat_format)
|
return handler(
|
||||||
result = format(
|
self,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
)
|
functions=functions,
|
||||||
prompt = result.prompt
|
function_call=function_call,
|
||||||
if result.stop is not None:
|
|
||||||
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
|
|
||||||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
|
||||||
stop = stop + rstop
|
|
||||||
|
|
||||||
completion_or_chunks = self.create_completion(
|
|
||||||
prompt=prompt,
|
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
|
@ -1678,7 +1599,6 @@ class Llama:
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
)
|
)
|
||||||
return self._convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
|
|
||||||
|
|
||||||
def _free_model(self, *, _lbatch_free=llama_cpp._lib.llama_batch_free, _lfree_model=llama_cpp._lib.llama_free_model, _free=llama_cpp._lib.llama_free):
|
def _free_model(self, *, _lbatch_free=llama_cpp._lib.llama_batch_free, _lfree_model=llama_cpp._lib.llama_free_model, _free=llama_cpp._lib.llama_free):
|
||||||
batch = getattr(self, 'batch', None)
|
batch = getattr(self, 'batch', None)
|
||||||
|
|
|
@ -1,6 +1,53 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union, Protocol
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
||||||
from . import llama_types
|
from . import llama_types
|
||||||
|
from . import llama
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaChatCompletionHandler(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
llama: llama.Llama,
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
|
function_call: Optional[
|
||||||
|
Union[str, llama_types.ChatCompletionFunctionCall]
|
||||||
|
] = None,
|
||||||
|
temperature: float = 0.2,
|
||||||
|
top_p: float = 0.95,
|
||||||
|
top_k: int = 40,
|
||||||
|
stream: bool = False,
|
||||||
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
|
max_tokens: int = 256,
|
||||||
|
presence_penalty: float = 0.0,
|
||||||
|
frequency_penalty: float = 0.0,
|
||||||
|
repeat_penalty: float = 1.1,
|
||||||
|
tfs_z: float = 1.0,
|
||||||
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
|
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
CHAT_HANDLERS: Dict[str, LlamaChatCompletionHandler] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_chat_completion_handler(name: str) -> LlamaChatCompletionHandler:
|
||||||
|
return CHAT_HANDLERS[name]
|
||||||
|
|
||||||
|
|
||||||
|
def register_chat_completion_handler(name: str):
|
||||||
|
def decorator(f: LlamaChatCompletionHandler):
|
||||||
|
CHAT_HANDLERS[name] = f
|
||||||
|
return f
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def _get_system_message(
|
def _get_system_message(
|
||||||
|
@ -119,12 +166,150 @@ class ChatFormatter(Protocol):
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
|
class BasicChatHandler:
|
||||||
|
def __init__(self, chat_format: str):
|
||||||
|
self.chat_format = chat_format
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_text_completion_to_chat(
|
||||||
|
completion: llama_types.Completion,
|
||||||
|
) -> llama_types.ChatCompletion:
|
||||||
|
return {
|
||||||
|
"id": "chat" + completion["id"],
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": completion["created"],
|
||||||
|
"model": completion["model"],
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": completion["choices"][0]["text"],
|
||||||
|
},
|
||||||
|
"finish_reason": completion["choices"][0]["finish_reason"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": completion["usage"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_text_completion_chunks_to_chat(
|
||||||
|
chunks: Iterator[llama_types.CompletionChunk],
|
||||||
|
) -> Iterator[llama_types.ChatCompletionChunk]:
|
||||||
|
for i, chunk in enumerate(chunks):
|
||||||
|
if i == 0:
|
||||||
|
yield {
|
||||||
|
"id": "chat" + chunk["id"],
|
||||||
|
"model": chunk["model"],
|
||||||
|
"created": chunk["created"],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"role": "assistant",
|
||||||
|
},
|
||||||
|
"finish_reason": None,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
"id": "chat" + chunk["id"],
|
||||||
|
"model": chunk["model"],
|
||||||
|
"created": chunk["created"],
|
||||||
|
"object": "chat.completion.chunk",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"delta": {
|
||||||
|
"content": chunk["choices"][0]["text"],
|
||||||
|
}
|
||||||
|
if chunk["choices"][0]["finish_reason"] is None
|
||||||
|
else {},
|
||||||
|
"finish_reason": chunk["choices"][0]["finish_reason"],
|
||||||
|
}
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_completion_to_chat(
|
||||||
|
completion_or_chunks: Union[
|
||||||
|
llama_types.Completion, Iterator[llama_types.CompletionChunk]
|
||||||
|
],
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
|
||||||
|
if stream:
|
||||||
|
chunks: Iterator[llama_types.CompletionChunk] = completion_or_chunks # type: ignore
|
||||||
|
return _convert_text_completion_chunks_to_chat(chunks)
|
||||||
|
else:
|
||||||
|
completion: llama_types.Completion = completion_or_chunks # type: ignore
|
||||||
|
return _convert_text_completion_to_chat(completion)
|
||||||
|
|
||||||
|
|
||||||
_CHAT_FORMATS: Dict[str, ChatFormatter] = {}
|
_CHAT_FORMATS: Dict[str, ChatFormatter] = {}
|
||||||
|
|
||||||
|
|
||||||
def register_chat_format(name: str):
|
def register_chat_format(name: str):
|
||||||
def decorator(f: ChatFormatter):
|
def decorator(f: ChatFormatter):
|
||||||
_CHAT_FORMATS[name] = f
|
def basic_create_chat_completion(
|
||||||
|
llama: llama.Llama,
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
|
function_call: Optional[
|
||||||
|
Union[str, llama_types.ChatCompletionFunctionCall]
|
||||||
|
] = None,
|
||||||
|
temperature: float = 0.2,
|
||||||
|
top_p: float = 0.95,
|
||||||
|
top_k: int = 40,
|
||||||
|
stream: bool = False,
|
||||||
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
|
max_tokens: int = 256,
|
||||||
|
presence_penalty: float = 0.0,
|
||||||
|
frequency_penalty: float = 0.0,
|
||||||
|
repeat_penalty: float = 1.1,
|
||||||
|
tfs_z: float = 1.0,
|
||||||
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
|
) -> Union[
|
||||||
|
llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]
|
||||||
|
]:
|
||||||
|
result = f(
|
||||||
|
messages=messages,
|
||||||
|
functions=functions,
|
||||||
|
function_call=function_call,
|
||||||
|
)
|
||||||
|
prompt = result.prompt
|
||||||
|
if result.stop is not None:
|
||||||
|
stop = [] if stop is None else [stop] if isinstance(stop, str) else stop
|
||||||
|
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||||
|
stop = stop + rstop
|
||||||
|
|
||||||
|
completion_or_chunks = llama.create_completion(
|
||||||
|
prompt=prompt,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
stream=stream,
|
||||||
|
stop=stop,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
repeat_penalty=repeat_penalty,
|
||||||
|
tfs_z=tfs_z,
|
||||||
|
mirostat_mode=mirostat_mode,
|
||||||
|
mirostat_tau=mirostat_tau,
|
||||||
|
mirostat_eta=mirostat_eta,
|
||||||
|
model=model,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
grammar=grammar,
|
||||||
|
)
|
||||||
|
return _convert_completion_to_chat(completion_or_chunks, stream=stream) # type: ignore
|
||||||
|
|
||||||
|
register_chat_completion_handler(name)(basic_create_chat_completion)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
@ -320,3 +505,206 @@ def format_chatml(
|
||||||
_messages.append((_roles["assistant"], None))
|
_messages.append((_roles["assistant"], None))
|
||||||
_prompt = _format_chatml(system_message, _messages, _sep)
|
_prompt = _format_chatml(system_message, _messages, _sep)
|
||||||
return ChatFormatterResponse(prompt=_prompt)
|
return ChatFormatterResponse(prompt=_prompt)
|
||||||
|
|
||||||
|
|
||||||
|
@register_chat_completion_handler("functionary")
|
||||||
|
def functionary_chat_handler(
|
||||||
|
llama: llama.Llama,
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
|
function_call: Optional[Union[str, llama_types.ChatCompletionFunctionCall]] = None,
|
||||||
|
temperature: float = 0.2,
|
||||||
|
top_p: float = 0.95,
|
||||||
|
top_k: int = 40,
|
||||||
|
stream: bool = False,
|
||||||
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
|
max_tokens: int = 256,
|
||||||
|
presence_penalty: float = 0.0,
|
||||||
|
frequency_penalty: float = 0.0,
|
||||||
|
repeat_penalty: float = 1.1,
|
||||||
|
tfs_z: float = 1.0,
|
||||||
|
mirostat_mode: int = 0,
|
||||||
|
mirostat_tau: float = 5.0,
|
||||||
|
mirostat_eta: float = 0.1,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
logits_processor: Optional[llama.LogitsProcessorList] = None,
|
||||||
|
grammar: Optional[llama.LlamaGrammar] = None,
|
||||||
|
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
|
||||||
|
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
|
||||||
|
|
||||||
|
def generate_schema_from_functions(
|
||||||
|
functions: List[llama_types.ChatCompletionFunctions],
|
||||||
|
namespace: str = "functions",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Convert functions schema to a schema that language models can understand.
|
||||||
|
"""
|
||||||
|
|
||||||
|
schema = (
|
||||||
|
"// Supported function definitions that should be called when necessary.\n"
|
||||||
|
)
|
||||||
|
schema += f"namespace {namespace} {{\n\n"
|
||||||
|
|
||||||
|
for function in functions:
|
||||||
|
# Convert a Function object to dict, if necessary
|
||||||
|
function_name = function["name"]
|
||||||
|
description = function.get("description", "")
|
||||||
|
schema += f"// {description}\n"
|
||||||
|
schema += f"type {function_name}"
|
||||||
|
|
||||||
|
parameters = function.get("parameters", None)
|
||||||
|
schema += " = (_: {\n"
|
||||||
|
required_params = parameters.get("required", [])
|
||||||
|
for param_name, param in parameters.get("properties", {}).items():
|
||||||
|
# Param Description
|
||||||
|
description = param.get("description")
|
||||||
|
if description is not None:
|
||||||
|
schema += f"// {description}\n"
|
||||||
|
|
||||||
|
# Param Name
|
||||||
|
schema += f"{param_name}"
|
||||||
|
if param_name not in required_params:
|
||||||
|
schema += "?"
|
||||||
|
|
||||||
|
# Param Type
|
||||||
|
param_type = param.get("type", "any")
|
||||||
|
if param_type == "integer":
|
||||||
|
param_type = "number"
|
||||||
|
if "enum" in param:
|
||||||
|
param_type = " | ".join([f'"{v}"' for v in param["enum"]])
|
||||||
|
schema += f": {param_type},\n"
|
||||||
|
|
||||||
|
schema += "}) => any;\n\n"
|
||||||
|
|
||||||
|
schema += f"}} // namespace {namespace}"
|
||||||
|
|
||||||
|
return schema
|
||||||
|
|
||||||
|
def prepare_messages_for_inference(
|
||||||
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
|
functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
|
||||||
|
):
|
||||||
|
all_messages: List[llama_types.ChatCompletionRequestMessage] = []
|
||||||
|
if functions is not None:
|
||||||
|
all_messages.append(
|
||||||
|
llama_types.ChatCompletionRequestMessage(
|
||||||
|
role="system", content=generate_schema_from_functions(functions)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
all_messages.append(
|
||||||
|
llama_types.ChatCompletionRequestMessage(
|
||||||
|
role="system", content=SYSTEM_MESSAGE
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for message in messages:
|
||||||
|
# Function call responses
|
||||||
|
if message["role"] == "function" and "name" in message:
|
||||||
|
message["name"] = f"functions.{message['name']}"
|
||||||
|
# Function call requests by assistant
|
||||||
|
if "function_call" in message:
|
||||||
|
message["function_call"][
|
||||||
|
"name"
|
||||||
|
] = f"functions.{message['function_call']['name']}"
|
||||||
|
all_messages.append(message)
|
||||||
|
|
||||||
|
all_messages.append(
|
||||||
|
llama_types.ChatCompletionRequestMessage(role="assistant", content=None)
|
||||||
|
)
|
||||||
|
|
||||||
|
def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
|
||||||
|
if msg["role"] == "system":
|
||||||
|
return f"system:\n{msg['content']}\n"
|
||||||
|
|
||||||
|
elif msg["role"] == "function" and "name" in msg:
|
||||||
|
return f"function name={msg['name']}:\n{msg['content']}\n"
|
||||||
|
elif msg["role"] == "function" and "function_call" in msg:
|
||||||
|
return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
|
||||||
|
elif msg["role"] == "user":
|
||||||
|
if msg["content"] is None:
|
||||||
|
return "user:\n</s>"
|
||||||
|
else:
|
||||||
|
return f"user:\n</s>{msg['content']}\n"
|
||||||
|
elif msg["role"] == "assistant":
|
||||||
|
if msg["content"] is not None and "function_call" in msg:
|
||||||
|
return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
|
||||||
|
elif "function_call" in msg:
|
||||||
|
return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
|
||||||
|
elif msg["content"] is None:
|
||||||
|
return "assistant"
|
||||||
|
else:
|
||||||
|
return f"assistant:\n{msg['content']}\n"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported role: {msg['role']}")
|
||||||
|
|
||||||
|
return "".join([message_to_str(msg) for msg in all_messages])
|
||||||
|
|
||||||
|
prompt = prepare_messages_for_inference(messages, functions)
|
||||||
|
|
||||||
|
if function_call is None and (functions is None or len(functions) == 0):
|
||||||
|
completion_or_completion_chunks = llama.create_completion(
|
||||||
|
prompt=prompt + ":\n",
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
stream=stream,
|
||||||
|
stop=["user:", "</s>"],
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
repeat_penalty=repeat_penalty,
|
||||||
|
tfs_z=tfs_z,
|
||||||
|
mirostat_mode=mirostat_mode,
|
||||||
|
mirostat_tau=mirostat_tau,
|
||||||
|
mirostat_eta=mirostat_eta,
|
||||||
|
model=model,
|
||||||
|
logits_processor=logits_processor,
|
||||||
|
grammar=grammar,
|
||||||
|
)
|
||||||
|
return _convert_completion_to_chat(completion_or_completion_chunks, stream=stream) # type: ignore
|
||||||
|
|
||||||
|
if function_call is None or (
|
||||||
|
isinstance(function_call, str) and function_call == "auto"
|
||||||
|
):
|
||||||
|
stop = "\n"
|
||||||
|
completion: llama_types.Completion = llama.create_completion(
|
||||||
|
prompt=prompt, stop=stop, stream=False
|
||||||
|
) # type: ignore
|
||||||
|
completion_text = completion["choices"][0]["text"]
|
||||||
|
# strip " to=functions." and ending ":"
|
||||||
|
function_call = completion_text[14:-1]
|
||||||
|
new_prompt = prompt + completion_text + stop
|
||||||
|
elif isinstance(function_call, str) and function_call != "none":
|
||||||
|
new_prompt = prompt + f"assistant:\n"
|
||||||
|
elif isinstance(function_call, dict):
|
||||||
|
new_prompt = prompt + f"assistant to={function_call['name']}:\n"
|
||||||
|
function_call = function_call["name"]
|
||||||
|
else:
|
||||||
|
new_prompt = prompt + f"assistant:\n"
|
||||||
|
|
||||||
|
completion: llama_types.Completion = llama.create_completion(
|
||||||
|
prompt=new_prompt, stop=["user:", "</s>"], stream=False
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
return llama_types.CreateChatCompletionResponse(
|
||||||
|
id="chat" + completion["id"],
|
||||||
|
object="chat.completion",
|
||||||
|
created=completion["created"],
|
||||||
|
model=completion["model"],
|
||||||
|
choices=[
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"message": {
|
||||||
|
"role": "function",
|
||||||
|
"content": None,
|
||||||
|
"function_call": {
|
||||||
|
"name": function_call,
|
||||||
|
"arguments": completion["choices"][0]["text"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"finish_reason": "function_call",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
usage=completion["usage"],
|
||||||
|
)
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
"""C++ implementation of the llama grammar parser."""
|
"""Python implementation of llama grammar parser directly translated from C++ source file in vendor/llama.cpp/common/grammar-parser.cpp."""
|
||||||
|
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys
|
import sys
|
||||||
|
@ -1056,8 +1057,7 @@ def print_rule(
|
||||||
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
# fprintf(file, "%s ::= ", symbol_id_names.at(rule_id).c_str());
|
||||||
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
|
if rule.empty() or rule.back().type != llama_gretype.LLAMA_GRETYPE_END:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"malformed rule, does not end with LLAMA_GRETYPE_END: "
|
"malformed rule, does not end with LLAMA_GRETYPE_END: " + str(rule_id)
|
||||||
+ str(rule_id)
|
|
||||||
)
|
)
|
||||||
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
|
print(f"{symbol_id_names.at(rule_id)} ::=", file=file, end=" ")
|
||||||
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
# for (size_t i = 0, end = rule.size() - 1; i < end; i++) {
|
||||||
|
@ -1102,9 +1102,7 @@ def print_rule(
|
||||||
for i, elem in enumerate(rule[:-1]):
|
for i, elem in enumerate(rule[:-1]):
|
||||||
case = elem.type # type: llama_gretype
|
case = elem.type # type: llama_gretype
|
||||||
if case is llama_gretype.LLAMA_GRETYPE_END:
|
if case is llama_gretype.LLAMA_GRETYPE_END:
|
||||||
raise RuntimeError(
|
raise RuntimeError("unexpected end of rule: " + str(rule_id) + "," + str(i))
|
||||||
"unexpected end of rule: " + str(rule_id) + "," + str(i)
|
|
||||||
)
|
|
||||||
elif case is llama_gretype.LLAMA_GRETYPE_ALT:
|
elif case is llama_gretype.LLAMA_GRETYPE_ALT:
|
||||||
print("| ", file=file, end="")
|
print("| ", file=file, end="")
|
||||||
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
|
elif case is llama_gretype.LLAMA_GRETYPE_RULE_REF:
|
||||||
|
@ -1186,3 +1184,308 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
|
||||||
f"{print_grammar.__name__}: error printing grammar: {err}",
|
f"{print_grammar.__name__}: error printing grammar: {err}",
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
"""llama.cpp gbnf rules from vendor/llama.cpp/grammars"""
|
||||||
|
|
||||||
|
ARITHMETIC_GBNF = """\
|
||||||
|
root ::= (expr "=" ws term "\n")+
|
||||||
|
expr ::= term ([-+*/] term)*
|
||||||
|
term ::= ident | num | "(" ws expr ")" ws
|
||||||
|
ident ::= [a-z] [a-z0-9_]* ws
|
||||||
|
num ::= [0-9]+ ws
|
||||||
|
ws ::= [ \t\n]*
|
||||||
|
"""
|
||||||
|
|
||||||
|
C_GBNF = """\
|
||||||
|
root ::= (declaration)*
|
||||||
|
|
||||||
|
declaration ::= dataType identifier "(" parameter? ")" "{" statement* "}"
|
||||||
|
|
||||||
|
dataType ::= "int" ws | "float" ws | "char" ws
|
||||||
|
identifier ::= [a-zA-Z_] [a-zA-Z_0-9]*
|
||||||
|
|
||||||
|
parameter ::= dataType identifier
|
||||||
|
|
||||||
|
statement ::=
|
||||||
|
( dataType identifier ws "=" ws expression ";" ) |
|
||||||
|
( identifier ws "=" ws expression ";" ) |
|
||||||
|
( identifier ws "(" argList? ")" ";" ) |
|
||||||
|
( "return" ws expression ";" ) |
|
||||||
|
( "while" "(" condition ")" "{" statement* "}" ) |
|
||||||
|
( "for" "(" forInit ";" ws condition ";" ws forUpdate ")" "{" statement* "}" ) |
|
||||||
|
( "if" "(" condition ")" "{" statement* "}" ("else" "{" statement* "}")? ) |
|
||||||
|
( singleLineComment ) |
|
||||||
|
( multiLineComment )
|
||||||
|
|
||||||
|
forInit ::= dataType identifier ws "=" ws expression | identifier ws "=" ws expression
|
||||||
|
forUpdate ::= identifier ws "=" ws expression
|
||||||
|
|
||||||
|
condition ::= expression relationOperator expression
|
||||||
|
relationOperator ::= ("<=" | "<" | "==" | "!=" | ">=" | ">")
|
||||||
|
|
||||||
|
expression ::= term (("+" | "-") term)*
|
||||||
|
term ::= factor(("*" | "/") factor)*
|
||||||
|
|
||||||
|
factor ::= identifier | number | unaryTerm | funcCall | parenExpression
|
||||||
|
unaryTerm ::= "-" factor
|
||||||
|
funcCall ::= identifier "(" argList? ")"
|
||||||
|
parenExpression ::= "(" ws expression ws ")"
|
||||||
|
|
||||||
|
argList ::= expression ("," ws expression)*
|
||||||
|
|
||||||
|
number ::= [0-9]+
|
||||||
|
|
||||||
|
singleLineComment ::= "//" [^\n]* "\n"
|
||||||
|
multiLineComment ::= "/*" ( [^*] | ("*" [^/]) )* "*/"
|
||||||
|
|
||||||
|
ws ::= ([ \t\n]+)
|
||||||
|
"""
|
||||||
|
|
||||||
|
CHESS_GBNF = """\
|
||||||
|
root ::= object
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
|
)* "\"" ws
|
||||||
|
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
|
ws ::= ([ \t\n] ws)?
|
||||||
|
"""
|
||||||
|
|
||||||
|
JAPANESE_GBNF = """\
|
||||||
|
root ::= object
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
|
)* "\"" ws
|
||||||
|
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
|
ws ::= ([ \t\n] ws)?
|
||||||
|
"""
|
||||||
|
|
||||||
|
JSON_ARR_GBNF = """\
|
||||||
|
# This is the same as json.gbnf but we restrict whitespaces at the end of the root array
|
||||||
|
# Useful for generating JSON arrays
|
||||||
|
|
||||||
|
root ::= arr
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
|
arr ::=
|
||||||
|
"[\n" ws (
|
||||||
|
value
|
||||||
|
(",\n" ws value)*
|
||||||
|
)? "]"
|
||||||
|
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
|
)* "\"" ws
|
||||||
|
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
|
ws ::= ([ \t\n] ws)?
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
JSON_GBNF = """\
|
||||||
|
root ::= object
|
||||||
|
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||||
|
|
||||||
|
object ::=
|
||||||
|
"{" ws (
|
||||||
|
string ":" ws value
|
||||||
|
("," ws string ":" ws value)*
|
||||||
|
)? "}" ws
|
||||||
|
|
||||||
|
array ::=
|
||||||
|
"[" ws (
|
||||||
|
value
|
||||||
|
("," ws value)*
|
||||||
|
)? "]" ws
|
||||||
|
|
||||||
|
string ::=
|
||||||
|
"\"" (
|
||||||
|
[^"\\] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||||
|
)* "\"" ws
|
||||||
|
|
||||||
|
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||||
|
|
||||||
|
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||||
|
ws ::= ([ \t\n] ws)?"""
|
||||||
|
|
||||||
|
LIST_GBNF = """\
|
||||||
|
root ::= item+
|
||||||
|
|
||||||
|
# Excludes various line break characters
|
||||||
|
item ::= "- " [^\r\n\x0b\x0c\x85\u2028\u2029]+ "\n"
|
||||||
|
"""
|
||||||
|
|
||||||
|
"""llama.cpp json-schema to grammar converter from vendor/llama.cpp/examples/json-schema-to-grammar.py"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
# whitespace is constrained to a single space char to prevent model "running away" in
|
||||||
|
# whitespace. Also maybe improves generation quality?
|
||||||
|
SPACE_RULE = '" "?'
|
||||||
|
|
||||||
|
PRIMITIVE_RULES = {
|
||||||
|
"boolean": '("true" | "false") space',
|
||||||
|
"number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space',
|
||||||
|
"integer": '("-"? ([0-9] | [1-9] [0-9]*)) space',
|
||||||
|
"string": r""" "\"" (
|
||||||
|
[^"\\] |
|
||||||
|
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
|
||||||
|
)* "\"" space """,
|
||||||
|
"null": '"null" space',
|
||||||
|
}
|
||||||
|
|
||||||
|
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
|
||||||
|
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
|
||||||
|
GRAMMAR_LITERAL_ESCAPES = {"\r": "\\r", "\n": "\\n", '"': '\\"'}
|
||||||
|
|
||||||
|
|
||||||
|
class SchemaConverter:
|
||||||
|
def __init__(self, prop_order):
|
||||||
|
self._prop_order = prop_order
|
||||||
|
self._rules = {"space": SPACE_RULE}
|
||||||
|
|
||||||
|
def _format_literal(self, literal):
|
||||||
|
escaped = GRAMMAR_LITERAL_ESCAPE_RE.sub(
|
||||||
|
lambda m: GRAMMAR_LITERAL_ESCAPES.get(m.group(0)), json.dumps(literal)
|
||||||
|
)
|
||||||
|
return f'"{escaped}"'
|
||||||
|
|
||||||
|
def _add_rule(self, name, rule):
|
||||||
|
esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
|
||||||
|
if esc_name not in self._rules or self._rules[esc_name] == rule:
|
||||||
|
key = esc_name
|
||||||
|
else:
|
||||||
|
i = 0
|
||||||
|
while f"{esc_name}{i}" in self._rules:
|
||||||
|
i += 1
|
||||||
|
key = f"{esc_name}{i}"
|
||||||
|
self._rules[key] = rule
|
||||||
|
return key
|
||||||
|
|
||||||
|
def visit(self, schema, name):
|
||||||
|
schema_type = schema.get("type")
|
||||||
|
rule_name = name or "root"
|
||||||
|
|
||||||
|
if "oneOf" in schema or "anyOf" in schema:
|
||||||
|
rule = " | ".join(
|
||||||
|
(
|
||||||
|
self.visit(alt_schema, f'{name}{"-" if name else ""}{i}')
|
||||||
|
for i, alt_schema in enumerate(
|
||||||
|
schema.get("oneOf") or schema["anyOf"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
|
elif "const" in schema:
|
||||||
|
return self._add_rule(rule_name, self._format_literal(schema["const"]))
|
||||||
|
|
||||||
|
elif "enum" in schema:
|
||||||
|
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
|
||||||
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
|
elif schema_type == "object" and "properties" in schema:
|
||||||
|
# TODO: `required` keyword
|
||||||
|
prop_order = self._prop_order
|
||||||
|
prop_pairs = sorted(
|
||||||
|
schema["properties"].items(),
|
||||||
|
# sort by position in prop_order (if specified) then by key
|
||||||
|
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
rule = '"{" space'
|
||||||
|
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
|
||||||
|
prop_rule_name = self.visit(
|
||||||
|
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
|
||||||
|
)
|
||||||
|
if i > 0:
|
||||||
|
rule += ' "," space'
|
||||||
|
rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
|
||||||
|
rule += ' "}" space'
|
||||||
|
|
||||||
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
|
elif schema_type == "array" and "items" in schema:
|
||||||
|
# TODO `prefixItems` keyword
|
||||||
|
item_rule_name = self.visit(
|
||||||
|
schema["items"], f'{name}{"-" if name else ""}item'
|
||||||
|
)
|
||||||
|
rule = (
|
||||||
|
f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space'
|
||||||
|
)
|
||||||
|
return self._add_rule(rule_name, rule)
|
||||||
|
|
||||||
|
else:
|
||||||
|
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
|
||||||
|
return self._add_rule(
|
||||||
|
"root" if rule_name == "root" else schema_type,
|
||||||
|
PRIMITIVE_RULES[schema_type],
|
||||||
|
)
|
||||||
|
|
||||||
|
def format_grammar(self):
|
||||||
|
return "\n".join((f"{name} ::= {rule}" for name, rule in self._rules.items()))
|
||||||
|
|
||||||
|
|
||||||
|
def json_schema_to_gbnf(schema: str, prop_order: Optional[List[str]] = None):
|
||||||
|
prop_order = prop_order or []
|
||||||
|
schema = json.load(schema)
|
||||||
|
prop_order = {name: idx for idx, name in enumerate(prop_order)}
|
||||||
|
converter = SchemaConverter(prop_order)
|
||||||
|
converter.visit(schema, "")
|
||||||
|
return converter.format_grammar()
|
||||||
|
|
|
@ -743,21 +743,22 @@ async def create_embedding(
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionRequestMessage(BaseModel):
|
class ChatCompletionRequestMessage(BaseModel):
|
||||||
role: Literal["system", "user", "assistant"] = Field(
|
role: Literal["system", "user", "assistant", "function"] = Field(
|
||||||
default="user", description="The role of the message."
|
default="user", description="The role of the message."
|
||||||
)
|
)
|
||||||
content: str = Field(default="", description="The content of the message.")
|
content: Optional[str] = Field(default="", description="The content of the message.")
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
class CreateChatCompletionRequest(BaseModel):
|
class CreateChatCompletionRequest(BaseModel):
|
||||||
messages: List[ChatCompletionRequestMessage] = Field(
|
messages: List[Any] = Field(
|
||||||
default=[], description="A list of messages to generate completions for."
|
default=[], description="A list of messages to generate completions for."
|
||||||
)
|
)
|
||||||
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
|
functions: Optional[List[llama_cpp.ChatCompletionFunction]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="A list of functions to apply to the generated completions.",
|
description="A list of functions to apply to the generated completions.",
|
||||||
)
|
)
|
||||||
function_call: Optional[Union[str, llama_cpp.ChatCompletionFunctionCall]] = Field(
|
function_call: Optional[Union[Literal["auto", "none"], llama_cpp.ChatCompletionFunctionCallOption]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="A function to apply to the generated completions.",
|
description="A function to apply to the generated completions.",
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue