diff --git a/examples/notebooks/Functions.ipynb b/examples/notebooks/Functions.ipynb index 4d27bb0..81d58f6 100644 --- a/examples/notebooks/Functions.ipynb +++ b/examples/notebooks/Functions.ipynb @@ -2,34 +2,14 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 29, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{\n", - " \"id\": \"chatcmpl-a6db1bbb-a128-4c28-88fe-30717ec806b2\",\n", - " \"object\": \"chat.completion\",\n", - " \"created\": 1698989577,\n", - " \"model\": \"gpt-3.5-turbo-0613\",\n", - " \"choices\": [\n", - " {\n", - " \"index\": 0,\n", - " \"message\": {\n", - " \"role\": \"assistant\",\n", - " \"content\": \"The current weather in Boston is sunny with a temperature of 72 degrees\"\n", - " },\n", - " \"finish_reason\": \"length\"\n", - " }\n", - " ],\n", - " \"usage\": {\n", - " \"prompt_tokens\": 135,\n", - " \"completion_tokens\": 16,\n", - " \"total_tokens\": 151\n", - " }\n", - "}\n" + "ChatCompletion(id='chatcmpl-b6dcbb47-1120-4761-8cd9-83542c97647b', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content=\"The current temperature in San Francisco is 72 degrees Fahrenheit. It's a sunny day with clear skies, making it perfect for outdoor activities.\\n \", role='assistant', function_call=None, tool_calls=None))], created=1699602158, model='gpt-3.5-turbo-1106', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=38, prompt_tokens=135, total_tokens=173))\n" ] } ], @@ -37,88 +17,92 @@ "import openai\n", "import json\n", "\n", - "openai.api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n", - "openai.api_base = \"http://100.64.159.73:8000/v1\"\n", + "\n", + "client = openai.OpenAI(\n", + " api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\", # can be anything\n", + " base_url = \"http://100.64.159.73:8000/v1\"\n", + ")\n", "\n", "# Example dummy function hard coded to return the same weather\n", "# In production, this could be your backend API or an external API\n", "def get_current_weather(location, unit=\"fahrenheit\"):\n", " \"\"\"Get the current weather in a given location\"\"\"\n", - " weather_info = {\n", - " \"location\": location,\n", - " \"temperature\": \"72\",\n", - " \"unit\": unit,\n", - " \"forecast\": [\"sunny\", \"windy\"],\n", - " }\n", - " return json.dumps(weather_info)\n", + " if \"tokyo\" in location.lower():\n", + " return json.dumps({\"location\": \"Tokyo\", \"temperature\": \"10\", \"unit\": \"celsius\"})\n", + " elif \"san francisco\" in location.lower():\n", + " return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"72\", \"unit\": \"fahrenheit\"})\n", + " elif \"paris\" in location.lower():\n", + " return json.dumps({\"location\": \"Paris\", \"temperature\": \"22\", \"unit\": \"celsius\"})\n", + " else:\n", + " return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n", "\n", "def run_conversation():\n", - " # Step 1: send the conversation and available functions to GPT\n", - " messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}]\n", - " functions = [\n", + " # Step 1: send the conversation and available functions to the model\n", + " messages = [{\"role\": \"user\", \"content\": \"What's the weather like in San Francisco, Tokyo, and Paris?\"}]\n", + " tools = [\n", " {\n", - " \"name\": \"get_current_weather\",\n", - " \"description\": \"Get the current weather in a given location\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_weather\",\n", + " \"description\": \"Get the current weather in a given location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", + " },\n", + " \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n", " },\n", - " \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n", + " \"required\": [\"location\"],\n", " },\n", - " \"required\": [\"location\"],\n", " },\n", " }\n", " ]\n", - " response = openai.ChatCompletion.create(\n", - " model=\"gpt-3.5-turbo-0613\",\n", + " response = client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo-1106\",\n", " messages=messages,\n", - " functions=functions,\n", - " function_call=\"auto\", # auto is default, but we'll be explicit\n", + " tools=tools,\n", + " tool_choice=\"auto\", # auto is default, but we'll be explicit\n", " )\n", - " response_message = response[\"choices\"][0][\"message\"]\n", - "\n", - " # Step 2: check if GPT wanted to call a function\n", - " if response_message.get(\"function_call\"):\n", + " response_message = response.choices[0].message\n", + " tool_calls = response_message.tool_calls\n", + " # Step 2: check if the model wanted to call a function\n", + " if tool_calls:\n", " # Step 3: call the function\n", " # Note: the JSON response may not always be valid; be sure to handle errors\n", " available_functions = {\n", " \"get_current_weather\": get_current_weather,\n", " } # only one function in this example, but you can have multiple\n", - " function_name = response_message[\"function_call\"][\"name\"]\n", - " fuction_to_call = available_functions[function_name]\n", - " function_args = json.loads(response_message[\"function_call\"][\"arguments\"])\n", - " function_response = fuction_to_call(\n", - " location=function_args.get(\"location\"),\n", - " unit=function_args.get(\"unit\"),\n", - " )\n", - "\n", - " # Step 4: send the info on the function call and function response to GPT\n", " messages.append(response_message) # extend conversation with assistant's reply\n", - " messages.append(\n", - " {\n", - " \"role\": \"function\",\n", - " \"name\": function_name,\n", - " \"content\": function_response,\n", - " }\n", - " ) # extend conversation with function response\n", - " second_response = openai.ChatCompletion.create(\n", - " model=\"gpt-3.5-turbo-0613\",\n", + " # Step 4: send the info for each function call and function response to the model\n", + " for tool_call in tool_calls:\n", + " function_name = tool_call.function.name\n", + " function_to_call = available_functions[function_name]\n", + " function_args = json.loads(tool_call.function.arguments)\n", + " function_response = function_to_call(\n", + " location=function_args.get(\"location\"),\n", + " unit=function_args.get(\"unit\"),\n", + " )\n", + " messages.append(\n", + " {\n", + " \"tool_call_id\": tool_call.id,\n", + " \"role\": \"tool\",\n", + " \"name\": function_name,\n", + " \"content\": function_response,\n", + " }\n", + " ) # extend conversation with function response\n", + " second_response = client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo-1106\",\n", " messages=messages,\n", - " ) # get a new response from GPT where it can see the function response\n", + " ) # get a new response from the model where it can see the function response\n", " return second_response\n", - " else:\n", - " print(response)\n", - " print(\"No function\")\n", - "\n", "print(run_conversation())" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 30, "metadata": {}, "outputs": [ { @@ -130,66 +114,257 @@ } ], "source": [ + "import instructor\n", "from pydantic import BaseModel\n", - "from instructor import patch\n", "\n", - "patch()\n", + "# Enables `response_model`\n", + "client = instructor.patch(client=client)\n", "\n", "class UserDetail(BaseModel):\n", " name: str\n", " age: int\n", "\n", - "user: UserDetail = openai.ChatCompletion.create(\n", + "user = client.chat.completions.create(\n", " model=\"gpt-3.5-turbo\",\n", " response_model=UserDetail,\n", " messages=[\n", " {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n", " ]\n", ")\n", + "\n", + "assert isinstance(user, UserDetail)\n", + "assert user.name == \"Jason\"\n", + "assert user.age == 25\n", + "\n", "print(user)" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "import enum\n", + "\n", + "class Labels(str, enum.Enum):\n", + " \"\"\"Enumeration for single-label text classification.\"\"\"\n", + " SPAM = \"spam\"\n", + " NOT_SPAM = \"not_spam\"\n", + "\n", + "class SinglePrediction(BaseModel):\n", + " \"\"\"\n", + " Class for a single class label prediction.\n", + " \"\"\"\n", + " class_label: Labels\n", + "\n", + "def classify(data: str) -> SinglePrediction:\n", + " \"\"\"Perform single-label classification on the input text.\"\"\"\n", + " return client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo-0613\",\n", + " response_model=SinglePrediction,\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"Classify the following text: {data}\",\n", + " },\n", + " ],\n", + " ) # type: ignore\n", + "\n", + "prediction = classify(\"Hello there I'm a Nigerian prince and I want to give you money\")\n", + "assert prediction.class_label == Labels.SPAM" + ] + }, + { + "cell_type": "code", + "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "{\n", - " \"id\": \"chatcmpl-59bcefad-9df5-4d6b-802c-5537b3e9044e\",\n", - " \"object\": \"chat.completion\",\n", - " \"created\": 1698989585,\n", - " \"model\": \"gpt-3.5-turbo-0613\",\n", - " \"choices\": [\n", - " {\n", - " \"index\": 0,\n", - " \"message\": {\n", - " \"role\": \"assistant\",\n", - " \"content\": \"I don't have up-to-date information on the current weather conditions\"\n", - " },\n", - " \"finish_reason\": \"length\"\n", - " }\n", - " ],\n", - " \"usage\": {\n", - " \"prompt_tokens\": 62,\n", - " \"completion_tokens\": 16,\n", - " \"total_tokens\": 78\n", - " }\n", - "}\n" + "class_labels=[, ]\n" ] } ], "source": [ - "response = openai.ChatCompletion.create(\n", - " model=\"gpt-3.5-turbo-0613\",\n", + "from typing import List\n", + "\n", + "# Define Enum class for multiple labels\n", + "class MultiLabels(str, enum.Enum):\n", + " TECH_ISSUE = \"tech_issue\"\n", + " BILLING = \"billing\"\n", + " GENERAL_QUERY = \"general_query\"\n", + "\n", + "# Define the multi-class prediction model\n", + "class MultiClassPrediction(BaseModel):\n", + " \"\"\"\n", + " Class for a multi-class label prediction.\n", + " \"\"\"\n", + " class_labels: List[MultiLabels]\n", + "\n", + "def multi_classify(data: str) -> MultiClassPrediction:\n", + " \"\"\"Perform multi-label classification on the input text.\"\"\"\n", + " return client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo-0613\",\n", + " response_model=MultiClassPrediction,\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"Classify the following support ticket: {data}\",\n", + " },\n", + " ],\n", + " ) # type: ignore\n", + "\n", + "# Test multi-label classification\n", + "ticket = \"My account is locked and I can't access my billing info.\"\n", + "prediction = multi_classify(ticket)\n", + "print(prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "question='What is the meaning of life?' answer='The meaning of life, according to the Devil, is to live a life of sin and debauchery.'\n" + ] + } + ], + "source": [ + "from typing_extensions import Annotated\n", + "from pydantic import BaseModel, BeforeValidator\n", + "\n", + "from instructor import llm_validator\n", + "\n", + "\n", + "question = \"What is the meaning of life?\"\n", + "context = \"The according to the devil the meaning of live is to live a life of sin and debauchery.\"\n", + "\n", + "class QuestionAnswer(BaseModel):\n", + " question: str\n", + " answer: str\n", + "\n", + "qa: QuestionAnswer = client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " response_model=QuestionAnswer,\n", " messages=[\n", - " {\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}\n", - " ]\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"using the context: {context}\\n\\nAnswer the following question: {question}\",\n", + " },\n", + " ],\n", ")\n", - "print(response)" + "print(qa)\n", + "\n", + "class QuestionAnswerNoEvil(BaseModel):\n", + " question: str\n", + " answer: Annotated[\n", + " str,\n", + " BeforeValidator(\n", + " llm_validator(\"don't say objectionable things\", allow_override=True)\n", + " ),\n", + " ]\n", + "\n", + "try:\n", + " qa: QuestionAnswerNoEvil = client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo\",\n", + " response_model=QuestionAnswerNoEvil,\n", + " messages=[\n", + " {\n", + " \"role\": \"system\",\n", + " \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n", + " },\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": f\"using the context: {context}\\n\\nAnswer the following question: {question}\",\n", + " },\n", + " ],\n", + " )\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "question='What did the author do during college?' answer=[Fact(fact='The author, Jason Liu, studied Computational Mathematics and Physics in university.', substring_quote=['Computational Mathematics'])]\n" + ] + } + ], + "source": [ + "import re\n", + "from typing import List\n", + "\n", + "from pydantic import Field, BaseModel, model_validator, FieldValidationInfo\n", + "\n", + "class Fact(BaseModel):\n", + " fact: str = Field(...)\n", + " substring_quote: List[str] = Field(...)\n", + "\n", + " @model_validator(mode=\"after\")\n", + " def validate_sources(self, info: FieldValidationInfo) -> \"Fact\":\n", + " text_chunks = info.context.get(\"text_chunk\", None)\n", + " spans = list(self.get_spans(text_chunks))\n", + " self.substring_quote = [text_chunks[span[0] : span[1]] for span in spans]\n", + " return self\n", + "\n", + " def get_spans(self, context):\n", + " for quote in self.substring_quote:\n", + " yield from self._get_span(quote, context)\n", + "\n", + " def _get_span(self, quote, context):\n", + " for match in re.finditer(re.escape(quote), context):\n", + " yield match.span()\n", + "\n", + "class QuestionAnswer(BaseModel):\n", + " question: str = Field(...)\n", + " answer: List[Fact] = Field(...)\n", + "\n", + " @model_validator(mode=\"after\")\n", + " def validate_sources(self) -> \"QuestionAnswer\":\n", + " self.answer = [fact for fact in self.answer if len(fact.substring_quote) > 0]\n", + " return self\n", + "\n", + "\n", + "def ask_ai(question: str, context: str) -> QuestionAnswer:\n", + " return client.chat.completions.create(\n", + " model=\"gpt-3.5-turbo-0613\",\n", + " temperature=0.0,\n", + " response_model=QuestionAnswer,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": \"You are a world class algorithm to answer questions with correct and exact citations.\"},\n", + " {\"role\": \"user\", \"content\": f\"{context}\"},\n", + " {\"role\": \"user\", \"content\": f\"Question: {question}\"}\n", + " ],\n", + " validation_context={\"text_chunk\": context},\n", + " )\n", + "\n", + "question = \"What did the author do during college?\"\n", + "context = \"\"\"\n", + "My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.\n", + "I went to an arts high school but in university I studied Computational Mathematics and physics.\n", + "As part of coop I worked at many companies including Stitchfix, Facebook.\n", + "I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.\n", + "\"\"\"\n", + "\n", + "qa = ask_ai(question, context)\n", + "print(qa)" ] }, { diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 3df4ee3..0ab57f8 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import json import ctypes import dataclasses from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol @@ -31,7 +32,7 @@ class LlamaChatCompletionHandler(Protocol): response_format: Optional[ llama_types.ChatCompletionRequestResponseFormat ] = None, - max_tokens: int = 256, + max_tokens: Optional[int] = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, @@ -292,7 +293,7 @@ def register_chat_format(name: str): response_format: Optional[ llama_types.ChatCompletionRequestResponseFormat ] = None, - max_tokens: int = 256, + max_tokens: Optional[int] = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, @@ -573,13 +574,15 @@ def functionary_chat_handler( messages: List[llama_types.ChatCompletionRequestMessage], functions: Optional[List[llama_types.ChatCompletionFunction]] = None, function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, temperature: float = 0.2, top_p: float = 0.95, top_k: int = 40, stream: bool = False, stop: Optional[Union[str, List[str]]] = [], response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None, - max_tokens: int = 256, + max_tokens: Optional[int] = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1, @@ -594,57 +597,80 @@ def functionary_chat_handler( ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" - def generate_schema_from_functions( - functions: List[llama_types.ChatCompletionFunctions], - namespace: str = "functions", - ): - """ - Convert functions schema to a schema that language models can understand. - """ + def generate_type_definition(param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs) -> str: + indent = ' ' * indent_level + if '$ref' in param: + # Reference to a shared definition + ref_name = param['$ref'].split('/')[-1] # Extract the type name from the reference + return ref_name + elif param.get('type') == 'array': + items = param.get('items', {}) + item_type = generate_type_definition(items, indent_level + 1, shared_defs) + return f"Array<{item_type}>" + elif param.get('type') == 'object': + properties = param.get('properties', {}) + nested_schema = "{\n" + for nested_param_name, nested_param in properties.items(): + nested_param_type = generate_type_definition(nested_param, indent_level + 1, shared_defs) + nested_schema += f"{indent} {nested_param_name}: {nested_param_type},\n" + nested_schema += indent + "}" + return nested_schema + elif 'enum' in param: + # Enum type + return " | ".join([f'"{enum_value}"' for enum_value in param['enum']]) + else: + # Simple type + return param.get('type', 'any') - schema = ( - "// Supported function definitions that should be called when necessary.\n" - ) + def generate_shared_definitions(shared_defs, indent_level: int) -> str: + indent = ' ' * indent_level + shared_definitions = "" + for def_name, def_properties in shared_defs.items(): + shared_definitions += f"{indent}type {def_name} = " + if def_properties.get('type') == 'object': + shared_definitions += generate_type_definition(def_properties, indent_level, shared_defs) + elif 'enum' in def_properties: + # Enum type + shared_definitions += " | ".join([f'"{enum_value}"' for enum_value in def_properties['enum']]) + shared_definitions += ";\n" + return shared_definitions + + def generate_schema_from_functions(functions, namespace="functions") -> str: + schema = "// Supported function definitions that should be called when necessary.\n" schema += f"namespace {namespace} {{\n\n" + # Generate shared definitions + shared_definitions = {} + for function in functions: + parameters = function.get("parameters", {}) + shared_definitions.update(parameters.get("$defs", {})) + + schema += generate_shared_definitions(shared_definitions, 1) + for function in functions: - # Convert a Function object to dict, if necessary function_name = function["name"] description = function.get("description", "") - schema += f"// {description}\n" - schema += f"type {function_name}" - - parameters = function.get("parameters", None) - schema += " = (_: {\n" + parameters = function.get("parameters", {}) required_params = parameters.get("required", []) + + schema += f" // {description}\n" + schema += f" type {function_name} = (_: {{\n" + for param_name, param in parameters.get("properties", {}).items(): - # Param Description - description = param.get("description") - if description is not None: - schema += f"// {description}\n" - - # Param Name - schema += f"{param_name}" - if param_name not in required_params: - schema += "?" - - # Param Type - param_type = param.get("type", "any") - if param_type == "integer": - param_type = "number" - if "enum" in param: - param_type = " | ".join([f'"{v}"' for v in param["enum"]]) - schema += f": {param_type},\n" - - schema += "}) => any;\n\n" - - schema += f"}} // namespace {namespace}" + param_description = param.get("description", "") + param_type = generate_type_definition(param, 2, shared_definitions) + optional_indicator = "" if param_name in required_params else "?" + schema += f" // {param_description}\n" + schema += f" {param_name}{optional_indicator}: {param_type},\n" + schema += " }) => any;\n\n" + schema += "}} // namespace {}\n".format(namespace) return schema def prepare_messages_for_inference( messages: List[llama_types.ChatCompletionRequestMessage], functions: Optional[List[llama_types.ChatCompletionFunctions]] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, ): all_messages: List[llama_types.ChatCompletionRequestMessage] = [] if functions is not None: @@ -653,6 +679,15 @@ def functionary_chat_handler( role="system", content=generate_schema_from_functions(functions) ) ) + + if tools is not None: + all_messages.append( + llama_types.ChatCompletionRequestSystemMessage( + role="system", content=generate_schema_from_functions( + [tool["function"] for tool in tools if tool["type"] == "function"] + ) + ) + ) all_messages.append( llama_types.ChatCompletionRequestSystemMessage( @@ -685,16 +720,24 @@ def functionary_chat_handler( return f"function name={msg['name']}:\n{msg['content']}\n" elif msg["role"] == "function" and "function_call" in msg: return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n" + elif msg["role"] == "tool": + if msg["content"] is not None: + return f"function name={msg['tool_call_id']}:\n{msg['content']}\n" + else: + return f"function name={msg['tool_call_id']}\n" elif msg["role"] == "user": if msg["content"] is None: - return "user:\n" + return "user:\n\n" else: - return f"user:\n{msg['content']}\n" + return f"user:\n{msg['content']}\n" elif msg["role"] == "assistant": if msg["content"] is not None and "function_call" in msg: - return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n" elif "function_call" in msg: - return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}" + return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n" + elif "tool_calls" in msg and len(msg["tool_calls"]) > 0: + for tool_call in msg["tool_calls"]: # NOTE: probably doesn't work with the functionary model + return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}\n" elif msg["content"] is None: return "assistant" else: @@ -703,8 +746,14 @@ def functionary_chat_handler( raise ValueError(f"Unsupported role: {msg['role']}") return "".join([message_to_str(msg) for msg in all_messages]) + + if tools is not None: + functions = [tool["function"] for tool in tools if tool["type"] == "function"] + + if tool_choice is not None: + function_call = tool_choice if isinstance(tool_choice, str) else tool_choice["function"] - prompt = prepare_messages_for_inference(messages, functions) + prompt = prepare_messages_for_inference(messages, functions, tools) if function_call is None and (functions is None or len(functions) == 0): completion_or_completion_chunks = llama.create_completion( @@ -737,27 +786,68 @@ def functionary_chat_handler( ) # type: ignore completion_text = completion["choices"][0]["text"] # strip " to=functions." and ending ":" - function_call = completion_text[14:-1] + function_call = completion_text.split(".")[-1][:-1] new_prompt = prompt + completion_text + stop elif isinstance(function_call, str) and function_call != "none": - new_prompt = prompt + f"assistant:\n" + new_prompt = prompt + f":\n" elif isinstance(function_call, dict): - new_prompt = prompt + f"assistant to={function_call['name']}:\n" + new_prompt = prompt + f" to=functions.{function_call['name']}:\n" function_call = function_call["name"] else: - new_prompt = prompt + f"assistant:\n" + new_prompt = prompt + f":\n" + + function_body = None + for function in functions or []: + if function["name"] == function_call: + function_body = function["parameters"] + break + for tool in tools or []: + if tool["type"] == "function" and tool["function"]["name"] == function_call: + function_body = tool["function"]["parameters"] + break + + if function_body is not None: + try: + with suppress_stdout_stderr(disable=llama.verbose): + grammar_text = llama_grammar.json_schema_to_gbnf(json.dumps(function_body)) + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.json_schema_to_gbnf(json.dumps(function_body))) + print(grammar_text) + except Exception as e: + if llama.verbose: + print("Failed to parse function body as JSON schema, falling back to default grammar") + print(e) + with suppress_stdout_stderr(disable=llama.verbose): + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) + else: + with suppress_stdout_stderr(disable=llama.verbose): + grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) completion: llama_types.Completion = llama.create_completion( - prompt=new_prompt, stop=["user:", ""], stream=False + prompt=new_prompt, + stop=["user:", ""], + stream=False, + grammar=grammar, + max_tokens=max_tokens, + temperature=temperature, + top_p=top_p, + top_k=top_k, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repeat_penalty=repeat_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + model=model, + logits_processor=logits_processor, ) # type: ignore assert "usage" in completion assert isinstance(function_call, str) assert stream is False # TODO: support stream mode - if response_format is not None and response_format["type"] == "json_object": - with suppress_stdout_stderr(disable=llama.verbose): - grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF) + print(new_prompt) + print(completion["choices"][0]["text"]) return llama_types.CreateChatCompletionResponse( id="chat" + completion["id"], @@ -768,14 +858,24 @@ def functionary_chat_handler( { "index": 0, "message": { - "role": "function", + "role": "assistant", "content": None, "function_call": { "name": function_call, "arguments": completion["choices"][0]["text"], }, + "tool_calls": [ + { + "id": function_call, + "type": "function", + "function": { + "name": function_call, + "arguments": completion["choices"][0]["text"], + } + } + ] }, - "finish_reason": "function_call", + "finish_reason": "tool_calls", } ], usage=completion["usage"], @@ -834,7 +934,7 @@ class Llava15ChatHandler: response_format: Optional[ llama_types.ChatCompletionRequestResponseFormat ] = None, - max_tokens: int = 256, + max_tokens: Optional[int] = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, repeat_penalty: float = 1.1,