Update functionary for new OpenAI API

Andrei Betlen 2023-11-10 02:51:58 -05:00
parent 17da8fb446
commit 1b376c62b7
2 changed files with 437 additions and 162 deletions
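For context, the interface this commit migrates to is openai-python 1.x: module-level calls and dict-style responses are replaced by a client object, the functions/function_call parameters become tools/tool_choice, and responses are read through attributes. A minimal before/after sketch, assuming a local llama-cpp-python server; the base_url, API key, and model name below are placeholders rather than values taken from this commit:

import openai

# New-style client (openai >= 1.0); the key can be anything for a local server.
client = openai.OpenAI(
    api_key="sk-xxx",                         # placeholder
    base_url="http://localhost:8000/v1",      # placeholder for a llama-cpp-python server
)

tools = [
    {
        "type": "function",                   # the old function schema, wrapped in a tool entry
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = client.chat.completions.create(   # was: openai.ChatCompletion.create(...)
    model="gpt-3.5-turbo-1106",
    messages=[{"role": "user", "content": "What's the weather like in Paris?"}],
    tools=tools,                              # was: functions=...
    tool_choice="auto",                       # was: function_call="auto"
)

message = response.choices[0].message        # attribute access, was response["choices"][0]["message"]
tool_calls = message.tool_calls              # was response_message.get("function_call")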

View file

@@ -2,34 +2,14 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 29,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"{\n", "ChatCompletion(id='chatcmpl-b6dcbb47-1120-4761-8cd9-83542c97647b', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content=\"The current temperature in San Francisco is 72 degrees Fahrenheit. It's a sunny day with clear skies, making it perfect for outdoor activities.\\n \", role='assistant', function_call=None, tool_calls=None))], created=1699602158, model='gpt-3.5-turbo-1106', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=38, prompt_tokens=135, total_tokens=173))\n"
" \"id\": \"chatcmpl-a6db1bbb-a128-4c28-88fe-30717ec806b2\",\n",
" \"object\": \"chat.completion\",\n",
" \"created\": 1698989577,\n",
" \"model\": \"gpt-3.5-turbo-0613\",\n",
" \"choices\": [\n",
" {\n",
" \"index\": 0,\n",
" \"message\": {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"The current weather in Boston is sunny with a temperature of 72 degrees\"\n",
" },\n",
" \"finish_reason\": \"length\"\n",
" }\n",
" ],\n",
" \"usage\": {\n",
" \"prompt_tokens\": 135,\n",
" \"completion_tokens\": 16,\n",
" \"total_tokens\": 151\n",
" }\n",
"}\n"
] ]
} }
], ],
@@ -37,26 +17,32 @@
"import openai\n", "import openai\n",
"import json\n", "import json\n",
"\n", "\n",
"openai.api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n", "\n",
"openai.api_base = \"http://100.64.159.73:8000/v1\"\n", "client = openai.OpenAI(\n",
" api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\", # can be anything\n",
" base_url = \"http://100.64.159.73:8000/v1\"\n",
")\n",
"\n", "\n",
"# Example dummy function hard coded to return the same weather\n", "# Example dummy function hard coded to return the same weather\n",
"# In production, this could be your backend API or an external API\n", "# In production, this could be your backend API or an external API\n",
"def get_current_weather(location, unit=\"fahrenheit\"):\n", "def get_current_weather(location, unit=\"fahrenheit\"):\n",
" \"\"\"Get the current weather in a given location\"\"\"\n", " \"\"\"Get the current weather in a given location\"\"\"\n",
" weather_info = {\n", " if \"tokyo\" in location.lower():\n",
" \"location\": location,\n", " return json.dumps({\"location\": \"Tokyo\", \"temperature\": \"10\", \"unit\": \"celsius\"})\n",
" \"temperature\": \"72\",\n", " elif \"san francisco\" in location.lower():\n",
" \"unit\": unit,\n", " return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"72\", \"unit\": \"fahrenheit\"})\n",
" \"forecast\": [\"sunny\", \"windy\"],\n", " elif \"paris\" in location.lower():\n",
" }\n", " return json.dumps({\"location\": \"Paris\", \"temperature\": \"22\", \"unit\": \"celsius\"})\n",
" return json.dumps(weather_info)\n", " else:\n",
" return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n",
"\n", "\n",
"def run_conversation():\n", "def run_conversation():\n",
" # Step 1: send the conversation and available functions to GPT\n", " # Step 1: send the conversation and available functions to the model\n",
" messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}]\n", " messages = [{\"role\": \"user\", \"content\": \"What's the weather like in San Francisco, Tokyo, and Paris?\"}]\n",
" functions = [\n", " tools = [\n",
" {\n", " {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_current_weather\",\n", " \"name\": \"get_current_weather\",\n",
" \"description\": \"Get the current weather in a given location\",\n", " \"description\": \"Get the current weather in a given location\",\n",
" \"parameters\": {\n", " \"parameters\": {\n",
@@ -70,55 +56,53 @@
" },\n", " },\n",
" \"required\": [\"location\"],\n", " \"required\": [\"location\"],\n",
" },\n", " },\n",
" },\n",
" }\n", " }\n",
" ]\n", " ]\n",
" response = openai.ChatCompletion.create(\n", " response = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n", " model=\"gpt-3.5-turbo-1106\",\n",
" messages=messages,\n", " messages=messages,\n",
" functions=functions,\n", " tools=tools,\n",
" function_call=\"auto\", # auto is default, but we'll be explicit\n", " tool_choice=\"auto\", # auto is default, but we'll be explicit\n",
" )\n", " )\n",
" response_message = response[\"choices\"][0][\"message\"]\n", " response_message = response.choices[0].message\n",
"\n", " tool_calls = response_message.tool_calls\n",
" # Step 2: check if GPT wanted to call a function\n", " # Step 2: check if the model wanted to call a function\n",
" if response_message.get(\"function_call\"):\n", " if tool_calls:\n",
" # Step 3: call the function\n", " # Step 3: call the function\n",
" # Note: the JSON response may not always be valid; be sure to handle errors\n", " # Note: the JSON response may not always be valid; be sure to handle errors\n",
" available_functions = {\n", " available_functions = {\n",
" \"get_current_weather\": get_current_weather,\n", " \"get_current_weather\": get_current_weather,\n",
" } # only one function in this example, but you can have multiple\n", " } # only one function in this example, but you can have multiple\n",
" function_name = response_message[\"function_call\"][\"name\"]\n", " messages.append(response_message) # extend conversation with assistant's reply\n",
" fuction_to_call = available_functions[function_name]\n", " # Step 4: send the info for each function call and function response to the model\n",
" function_args = json.loads(response_message[\"function_call\"][\"arguments\"])\n", " for tool_call in tool_calls:\n",
" function_response = fuction_to_call(\n", " function_name = tool_call.function.name\n",
" function_to_call = available_functions[function_name]\n",
" function_args = json.loads(tool_call.function.arguments)\n",
" function_response = function_to_call(\n",
" location=function_args.get(\"location\"),\n", " location=function_args.get(\"location\"),\n",
" unit=function_args.get(\"unit\"),\n", " unit=function_args.get(\"unit\"),\n",
" )\n", " )\n",
"\n",
" # Step 4: send the info on the function call and function response to GPT\n",
" messages.append(response_message) # extend conversation with assistant's reply\n",
" messages.append(\n", " messages.append(\n",
" {\n", " {\n",
" \"role\": \"function\",\n", " \"tool_call_id\": tool_call.id,\n",
" \"role\": \"tool\",\n",
" \"name\": function_name,\n", " \"name\": function_name,\n",
" \"content\": function_response,\n", " \"content\": function_response,\n",
" }\n", " }\n",
" ) # extend conversation with function response\n", " ) # extend conversation with function response\n",
" second_response = openai.ChatCompletion.create(\n", " second_response = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n", " model=\"gpt-3.5-turbo-1106\",\n",
" messages=messages,\n", " messages=messages,\n",
" ) # get a new response from GPT where it can see the function response\n", " ) # get a new response from the model where it can see the function response\n",
" return second_response\n", " return second_response\n",
" else:\n",
" print(response)\n",
" print(\"No function\")\n",
"\n",
"print(run_conversation())" "print(run_conversation())"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 30,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@@ -130,66 +114,257 @@
} }
], ],
"source": [ "source": [
"import instructor\n",
"from pydantic import BaseModel\n", "from pydantic import BaseModel\n",
"from instructor import patch\n",
"\n", "\n",
"patch()\n", "# Enables `response_model`\n",
"client = instructor.patch(client=client)\n",
"\n", "\n",
"class UserDetail(BaseModel):\n", "class UserDetail(BaseModel):\n",
" name: str\n", " name: str\n",
" age: int\n", " age: int\n",
"\n", "\n",
"user: UserDetail = openai.ChatCompletion.create(\n", "user = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n", " model=\"gpt-3.5-turbo\",\n",
" response_model=UserDetail,\n", " response_model=UserDetail,\n",
" messages=[\n", " messages=[\n",
" {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n", " {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n",
" ]\n", " ]\n",
")\n", ")\n",
"\n",
"assert isinstance(user, UserDetail)\n",
"assert user.name == \"Jason\"\n",
"assert user.age == 25\n",
"\n",
"print(user)" "print(user)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"import enum\n",
"\n",
"class Labels(str, enum.Enum):\n",
" \"\"\"Enumeration for single-label text classification.\"\"\"\n",
" SPAM = \"spam\"\n",
" NOT_SPAM = \"not_spam\"\n",
"\n",
"class SinglePrediction(BaseModel):\n",
" \"\"\"\n",
" Class for a single class label prediction.\n",
" \"\"\"\n",
" class_label: Labels\n",
"\n",
"def classify(data: str) -> SinglePrediction:\n",
" \"\"\"Perform single-label classification on the input text.\"\"\"\n",
" return client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n",
" response_model=SinglePrediction,\n",
" messages=[\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": f\"Classify the following text: {data}\",\n",
" },\n",
" ],\n",
" ) # type: ignore\n",
"\n",
"prediction = classify(\"Hello there I'm a Nigerian prince and I want to give you money\")\n",
"assert prediction.class_label == Labels.SPAM"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"{\n", "class_labels=[<MultiLabels.BILLING: 'billing'>, <MultiLabels.TECH_ISSUE: 'tech_issue'>]\n"
" \"id\": \"chatcmpl-59bcefad-9df5-4d6b-802c-5537b3e9044e\",\n",
" \"object\": \"chat.completion\",\n",
" \"created\": 1698989585,\n",
" \"model\": \"gpt-3.5-turbo-0613\",\n",
" \"choices\": [\n",
" {\n",
" \"index\": 0,\n",
" \"message\": {\n",
" \"role\": \"assistant\",\n",
" \"content\": \"I don't have up-to-date information on the current weather conditions\"\n",
" },\n",
" \"finish_reason\": \"length\"\n",
" }\n",
" ],\n",
" \"usage\": {\n",
" \"prompt_tokens\": 62,\n",
" \"completion_tokens\": 16,\n",
" \"total_tokens\": 78\n",
" }\n",
"}\n"
] ]
} }
], ],
"source": [ "source": [
"response = openai.ChatCompletion.create(\n", "from typing import List\n",
"\n",
"# Define Enum class for multiple labels\n",
"class MultiLabels(str, enum.Enum):\n",
" TECH_ISSUE = \"tech_issue\"\n",
" BILLING = \"billing\"\n",
" GENERAL_QUERY = \"general_query\"\n",
"\n",
"# Define the multi-class prediction model\n",
"class MultiClassPrediction(BaseModel):\n",
" \"\"\"\n",
" Class for a multi-class label prediction.\n",
" \"\"\"\n",
" class_labels: List[MultiLabels]\n",
"\n",
"def multi_classify(data: str) -> MultiClassPrediction:\n",
" \"\"\"Perform multi-label classification on the input text.\"\"\"\n",
" return client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n", " model=\"gpt-3.5-turbo-0613\",\n",
" response_model=MultiClassPrediction,\n",
" messages=[\n", " messages=[\n",
" {\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}\n", " {\n",
" ]\n", " \"role\": \"user\",\n",
" \"content\": f\"Classify the following support ticket: {data}\",\n",
" },\n",
" ],\n",
" ) # type: ignore\n",
"\n",
"# Test multi-label classification\n",
"ticket = \"My account is locked and I can't access my billing info.\"\n",
"prediction = multi_classify(ticket)\n",
"print(prediction)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"question='What is the meaning of life?' answer='The meaning of life, according to the Devil, is to live a life of sin and debauchery.'\n"
]
}
],
"source": [
"from typing_extensions import Annotated\n",
"from pydantic import BaseModel, BeforeValidator\n",
"\n",
"from instructor import llm_validator\n",
"\n",
"\n",
"question = \"What is the meaning of life?\"\n",
"context = \"The according to the devil the meaning of live is to live a life of sin and debauchery.\"\n",
"\n",
"class QuestionAnswer(BaseModel):\n",
" question: str\n",
" answer: str\n",
"\n",
"qa: QuestionAnswer = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" response_model=QuestionAnswer,\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": f\"using the context: {context}\\n\\nAnswer the following question: {question}\",\n",
" },\n",
" ],\n",
")\n", ")\n",
"print(response)" "print(qa)\n",
"\n",
"class QuestionAnswerNoEvil(BaseModel):\n",
" question: str\n",
" answer: Annotated[\n",
" str,\n",
" BeforeValidator(\n",
" llm_validator(\"don't say objectionable things\", allow_override=True)\n",
" ),\n",
" ]\n",
"\n",
"try:\n",
" qa: QuestionAnswerNoEvil = client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo\",\n",
" response_model=QuestionAnswerNoEvil,\n",
" messages=[\n",
" {\n",
" \"role\": \"system\",\n",
" \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n",
" },\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": f\"using the context: {context}\\n\\nAnswer the following question: {question}\",\n",
" },\n",
" ],\n",
" )\n",
"except Exception as e:\n",
" print(e)"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"question='What did the author do during college?' answer=[Fact(fact='The author, Jason Liu, studied Computational Mathematics and Physics in university.', substring_quote=['Computational Mathematics'])]\n"
]
}
],
"source": [
"import re\n",
"from typing import List\n",
"\n",
"from pydantic import Field, BaseModel, model_validator, FieldValidationInfo\n",
"\n",
"class Fact(BaseModel):\n",
" fact: str = Field(...)\n",
" substring_quote: List[str] = Field(...)\n",
"\n",
" @model_validator(mode=\"after\")\n",
" def validate_sources(self, info: FieldValidationInfo) -> \"Fact\":\n",
" text_chunks = info.context.get(\"text_chunk\", None)\n",
" spans = list(self.get_spans(text_chunks))\n",
" self.substring_quote = [text_chunks[span[0] : span[1]] for span in spans]\n",
" return self\n",
"\n",
" def get_spans(self, context):\n",
" for quote in self.substring_quote:\n",
" yield from self._get_span(quote, context)\n",
"\n",
" def _get_span(self, quote, context):\n",
" for match in re.finditer(re.escape(quote), context):\n",
" yield match.span()\n",
"\n",
"class QuestionAnswer(BaseModel):\n",
" question: str = Field(...)\n",
" answer: List[Fact] = Field(...)\n",
"\n",
" @model_validator(mode=\"after\")\n",
" def validate_sources(self) -> \"QuestionAnswer\":\n",
" self.answer = [fact for fact in self.answer if len(fact.substring_quote) > 0]\n",
" return self\n",
"\n",
"\n",
"def ask_ai(question: str, context: str) -> QuestionAnswer:\n",
" return client.chat.completions.create(\n",
" model=\"gpt-3.5-turbo-0613\",\n",
" temperature=0.0,\n",
" response_model=QuestionAnswer,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": \"You are a world class algorithm to answer questions with correct and exact citations.\"},\n",
" {\"role\": \"user\", \"content\": f\"{context}\"},\n",
" {\"role\": \"user\", \"content\": f\"Question: {question}\"}\n",
" ],\n",
" validation_context={\"text_chunk\": context},\n",
" )\n",
"\n",
"question = \"What did the author do during college?\"\n",
"context = \"\"\"\n",
"My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.\n",
"I went to an arts high school but in university I studied Computational Mathematics and physics.\n",
"As part of coop I worked at many companies including Stitchfix, Facebook.\n",
"I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.\n",
"\"\"\"\n",
"\n",
"qa = ask_ai(question, context)\n",
"print(qa)"
] ]
}, },
{ {

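The second file updates the functionary chat handler to the same convention on the server side. Instead of a message with role "function" and finish_reason "function_call", a detected function invocation is now returned as an assistant message that carries both the legacy function_call field and a new tool_calls list. Roughly, one returned choice has the following shape (a hand-written illustration; the field values are made up, not captured from this commit):

# Illustrative only: approximate shape of a choice returned by the updated handler.
choice = {
    "index": 0,
    "message": {
        "role": "assistant",                      # was "function"
        "content": None,
        "function_call": {                        # kept for backwards compatibility
            "name": "get_current_weather",
            "arguments": '{"location": "Paris"}',
        },
        "tool_calls": [
            {
                "id": "get_current_weather",      # the handler reuses the function name as the id
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Paris"}',
                },
            }
        ],
    },
    "finish_reason": "tool_calls",                # was "function_call"
}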
View file

@@ -1,6 +1,7 @@
from __future__ import annotations from __future__ import annotations
import os import os
import json
import ctypes import ctypes
import dataclasses import dataclasses
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
@@ -31,7 +32,7 @@ class LlamaChatCompletionHandler(Protocol):
response_format: Optional[ response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat llama_types.ChatCompletionRequestResponseFormat
] = None, ] = None,
max_tokens: int = 256, max_tokens: Optional[int] = None,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
@@ -292,7 +293,7 @@ def register_chat_format(name: str):
response_format: Optional[ response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat llama_types.ChatCompletionRequestResponseFormat
] = None, ] = None,
max_tokens: int = 256, max_tokens: Optional[int] = None,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
@@ -573,13 +574,15 @@ def functionary_chat_handler(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunction]] = None, functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
temperature: float = 0.2, temperature: float = 0.2,
top_p: float = 0.95, top_p: float = 0.95,
top_k: int = 40, top_k: int = 40,
stream: bool = False, stream: bool = False,
stop: Optional[Union[str, List[str]]] = [], stop: Optional[Union[str, List[str]]] = [],
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
max_tokens: int = 256, max_tokens: Optional[int] = None,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
@@ -594,57 +597,80 @@ def functionary_chat_handler(
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
def generate_schema_from_functions( def generate_type_definition(param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs) -> str:
functions: List[llama_types.ChatCompletionFunctions], indent = ' ' * indent_level
namespace: str = "functions", if '$ref' in param:
): # Reference to a shared definition
""" ref_name = param['$ref'].split('/')[-1] # Extract the type name from the reference
Convert functions schema to a schema that language models can understand. return ref_name
""" elif param.get('type') == 'array':
items = param.get('items', {})
item_type = generate_type_definition(items, indent_level + 1, shared_defs)
return f"Array<{item_type}>"
elif param.get('type') == 'object':
properties = param.get('properties', {})
nested_schema = "{\n"
for nested_param_name, nested_param in properties.items():
nested_param_type = generate_type_definition(nested_param, indent_level + 1, shared_defs)
nested_schema += f"{indent} {nested_param_name}: {nested_param_type},\n"
nested_schema += indent + "}"
return nested_schema
elif 'enum' in param:
# Enum type
return " | ".join([f'"{enum_value}"' for enum_value in param['enum']])
else:
# Simple type
return param.get('type', 'any')
schema = ( def generate_shared_definitions(shared_defs, indent_level: int) -> str:
"// Supported function definitions that should be called when necessary.\n" indent = ' ' * indent_level
) shared_definitions = ""
for def_name, def_properties in shared_defs.items():
shared_definitions += f"{indent}type {def_name} = "
if def_properties.get('type') == 'object':
shared_definitions += generate_type_definition(def_properties, indent_level, shared_defs)
elif 'enum' in def_properties:
# Enum type
shared_definitions += " | ".join([f'"{enum_value}"' for enum_value in def_properties['enum']])
shared_definitions += ";\n"
return shared_definitions
def generate_schema_from_functions(functions, namespace="functions") -> str:
schema = "// Supported function definitions that should be called when necessary.\n"
schema += f"namespace {namespace} {{\n\n" schema += f"namespace {namespace} {{\n\n"
# Generate shared definitions
shared_definitions = {}
for function in functions:
parameters = function.get("parameters", {})
shared_definitions.update(parameters.get("$defs", {}))
schema += generate_shared_definitions(shared_definitions, 1)
for function in functions: for function in functions:
# Convert a Function object to dict, if necessary
function_name = function["name"] function_name = function["name"]
description = function.get("description", "") description = function.get("description", "")
schema += f"// {description}\n" parameters = function.get("parameters", {})
schema += f"type {function_name}"
parameters = function.get("parameters", None)
schema += " = (_: {\n"
required_params = parameters.get("required", []) required_params = parameters.get("required", [])
for param_name, param in parameters.get("properties", {}).items():
# Param Description
description = param.get("description")
if description is not None:
schema += f" // {description}\n" schema += f" // {description}\n"
schema += f" type {function_name} = (_: {{\n"
# Param Name for param_name, param in parameters.get("properties", {}).items():
schema += f"{param_name}" param_description = param.get("description", "")
if param_name not in required_params: param_type = generate_type_definition(param, 2, shared_definitions)
schema += "?" optional_indicator = "" if param_name in required_params else "?"
schema += f" // {param_description}\n"
# Param Type schema += f" {param_name}{optional_indicator}: {param_type},\n"
param_type = param.get("type", "any")
if param_type == "integer":
param_type = "number"
if "enum" in param:
param_type = " | ".join([f'"{v}"' for v in param["enum"]])
schema += f": {param_type},\n"
schema += " }) => any;\n\n" schema += " }) => any;\n\n"
schema += f"}} // namespace {namespace}" schema += "}} // namespace {}\n".format(namespace)
return schema return schema
def prepare_messages_for_inference( def prepare_messages_for_inference(
messages: List[llama_types.ChatCompletionRequestMessage], messages: List[llama_types.ChatCompletionRequestMessage],
functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
): ):
all_messages: List[llama_types.ChatCompletionRequestMessage] = [] all_messages: List[llama_types.ChatCompletionRequestMessage] = []
if functions is not None: if functions is not None:
@@ -654,6 +680,15 @@ def functionary_chat_handler(
) )
) )
if tools is not None:
all_messages.append(
llama_types.ChatCompletionRequestSystemMessage(
role="system", content=generate_schema_from_functions(
[tool["function"] for tool in tools if tool["type"] == "function"]
)
)
)
all_messages.append( all_messages.append(
llama_types.ChatCompletionRequestSystemMessage( llama_types.ChatCompletionRequestSystemMessage(
role="system", content=SYSTEM_MESSAGE role="system", content=SYSTEM_MESSAGE
@@ -685,16 +720,24 @@ def functionary_chat_handler(
return f"function name={msg['name']}:\n{msg['content']}\n" return f"function name={msg['name']}:\n{msg['content']}\n"
elif msg["role"] == "function" and "function_call" in msg: elif msg["role"] == "function" and "function_call" in msg:
return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n" return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
elif msg["role"] == "tool":
if msg["content"] is not None:
return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
else:
return f"function name={msg['tool_call_id']}\n"
elif msg["role"] == "user": elif msg["role"] == "user":
if msg["content"] is None: if msg["content"] is None:
return "user:\n</s>" return "user:\n</s></s>\n"
else: else:
return f"user:\n</s>{msg['content']}\n" return f"user:\n</s>{msg['content']}</s>\n"
elif msg["role"] == "assistant": elif msg["role"] == "assistant":
if msg["content"] is not None and "function_call" in msg: if msg["content"] is not None and "function_call" in msg:
return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>" return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
elif "function_call" in msg: elif "function_call" in msg:
return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>" return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
for tool_call in msg["tool_calls"]: # NOTE: probably doesn't work with the functionary model
return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
elif msg["content"] is None: elif msg["content"] is None:
return "assistant" return "assistant"
else: else:
@@ -704,7 +747,13 @@ def functionary_chat_handler(
return "".join([message_to_str(msg) for msg in all_messages]) return "".join([message_to_str(msg) for msg in all_messages])
prompt = prepare_messages_for_inference(messages, functions) if tools is not None:
functions = [tool["function"] for tool in tools if tool["type"] == "function"]
if tool_choice is not None:
function_call = tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
prompt = prepare_messages_for_inference(messages, functions, tools)
if function_call is None and (functions is None or len(functions) == 0): if function_call is None and (functions is None or len(functions) == 0):
completion_or_completion_chunks = llama.create_completion( completion_or_completion_chunks = llama.create_completion(
@@ -737,27 +786,68 @@ def functionary_chat_handler(
) # type: ignore ) # type: ignore
completion_text = completion["choices"][0]["text"] completion_text = completion["choices"][0]["text"]
# strip " to=functions." and ending ":" # strip " to=functions." and ending ":"
function_call = completion_text[14:-1] function_call = completion_text.split(".")[-1][:-1]
new_prompt = prompt + completion_text + stop new_prompt = prompt + completion_text + stop
elif isinstance(function_call, str) and function_call != "none": elif isinstance(function_call, str) and function_call != "none":
new_prompt = prompt + f"assistant:\n" new_prompt = prompt + f":\n"
elif isinstance(function_call, dict): elif isinstance(function_call, dict):
new_prompt = prompt + f"assistant to={function_call['name']}:\n" new_prompt = prompt + f" to=functions.{function_call['name']}:\n"
function_call = function_call["name"] function_call = function_call["name"]
else: else:
new_prompt = prompt + f"assistant:\n" new_prompt = prompt + f":\n"
function_body = None
for function in functions or []:
if function["name"] == function_call:
function_body = function["parameters"]
break
for tool in tools or []:
if tool["type"] == "function" and tool["function"]["name"] == function_call:
function_body = tool["function"]["parameters"]
break
if function_body is not None:
try:
with suppress_stdout_stderr(disable=llama.verbose):
grammar_text = llama_grammar.json_schema_to_gbnf(json.dumps(function_body))
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.json_schema_to_gbnf(json.dumps(function_body)))
print(grammar_text)
except Exception as e:
if llama.verbose:
print("Failed to parse function body as JSON schema, falling back to default grammar")
print(e)
with suppress_stdout_stderr(disable=llama.verbose):
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
else:
with suppress_stdout_stderr(disable=llama.verbose):
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
completion: llama_types.Completion = llama.create_completion( completion: llama_types.Completion = llama.create_completion(
prompt=new_prompt, stop=["user:", "</s>"], stream=False prompt=new_prompt,
stop=["user:", "</s>"],
stream=False,
grammar=grammar,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
top_k=top_k,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
repeat_penalty=repeat_penalty,
tfs_z=tfs_z,
mirostat_mode=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
model=model,
logits_processor=logits_processor,
) # type: ignore ) # type: ignore
assert "usage" in completion assert "usage" in completion
assert isinstance(function_call, str) assert isinstance(function_call, str)
assert stream is False # TODO: support stream mode assert stream is False # TODO: support stream mode
if response_format is not None and response_format["type"] == "json_object": print(new_prompt)
with suppress_stdout_stderr(disable=llama.verbose): print(completion["choices"][0]["text"])
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
return llama_types.CreateChatCompletionResponse( return llama_types.CreateChatCompletionResponse(
id="chat" + completion["id"], id="chat" + completion["id"],
@@ -768,14 +858,24 @@ def functionary_chat_handler(
{ {
"index": 0, "index": 0,
"message": { "message": {
"role": "function", "role": "assistant",
"content": None, "content": None,
"function_call": { "function_call": {
"name": function_call, "name": function_call,
"arguments": completion["choices"][0]["text"], "arguments": completion["choices"][0]["text"],
}, },
"tool_calls": [
{
"id": function_call,
"type": "function",
"function": {
"name": function_call,
"arguments": completion["choices"][0]["text"],
}
}
]
}, },
"finish_reason": "function_call", "finish_reason": "tool_calls",
} }
], ],
usage=completion["usage"], usage=completion["usage"],
@@ -834,7 +934,7 @@ class Llava15ChatHandler:
response_format: Optional[ response_format: Optional[
llama_types.ChatCompletionRequestResponseFormat llama_types.ChatCompletionRequestResponseFormat
] = None, ] = None,
max_tokens: int = 256, max_tokens: Optional[int] = None,
presence_penalty: float = 0.0, presence_penalty: float = 0.0,
frequency_penalty: float = 0.0, frequency_penalty: float = 0.0,
repeat_penalty: float = 1.1, repeat_penalty: float = 1.1,
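One further change worth calling out from the functionary handler above: when a function is selected, its JSON-schema parameters are now converted to a GBNF grammar so the sampled arguments are constrained to valid JSON. A standalone sketch of that pattern, assuming a locally available functionary GGUF file; the model path, prompt tail, and schema below are placeholders:

import json

from llama_cpp import Llama
import llama_cpp.llama_grammar as llama_grammar

llama = Llama(model_path="functionary-7b-v1.gguf")   # placeholder path

# JSON-schema parameters of the selected function (placeholder schema).
parameters = {
    "type": "object",
    "properties": {"location": {"type": "string"}},
    "required": ["location"],
}

# Convert the schema to GBNF and build a grammar, as the handler now does;
# on a malformed schema the handler falls back to the generic JSON_GBNF grammar.
gbnf = llama_grammar.json_schema_to_gbnf(json.dumps(parameters))
grammar = llama_grammar.LlamaGrammar.from_string(gbnf)

# Abbreviated prompt tail; the real handler prepends the full functionary prompt.
completion = llama.create_completion(
    prompt="assistant to=functions.get_current_weather:\n",
    stop=["user:", "</s>"],
    grammar=grammar,        # constrains sampling to arguments matching the schema
    max_tokens=256,
)
print(completion["choices"][0]["text"])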