Update functionary for new OpenAI API
This commit is contained in:
parent
17da8fb446
commit
1b376c62b7
2 changed files with 437 additions and 162 deletions
|
@ -2,34 +2,14 @@
|
||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 29,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"{\n",
|
"ChatCompletion(id='chatcmpl-b6dcbb47-1120-4761-8cd9-83542c97647b', choices=[Choice(finish_reason='stop', index=0, message=ChatCompletionMessage(content=\"The current temperature in San Francisco is 72 degrees Fahrenheit. It's a sunny day with clear skies, making it perfect for outdoor activities.\\n \", role='assistant', function_call=None, tool_calls=None))], created=1699602158, model='gpt-3.5-turbo-1106', object='chat.completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=38, prompt_tokens=135, total_tokens=173))\n"
|
||||||
" \"id\": \"chatcmpl-a6db1bbb-a128-4c28-88fe-30717ec806b2\",\n",
|
|
||||||
" \"object\": \"chat.completion\",\n",
|
|
||||||
" \"created\": 1698989577,\n",
|
|
||||||
" \"model\": \"gpt-3.5-turbo-0613\",\n",
|
|
||||||
" \"choices\": [\n",
|
|
||||||
" {\n",
|
|
||||||
" \"index\": 0,\n",
|
|
||||||
" \"message\": {\n",
|
|
||||||
" \"role\": \"assistant\",\n",
|
|
||||||
" \"content\": \"The current weather in Boston is sunny with a temperature of 72 degrees\"\n",
|
|
||||||
" },\n",
|
|
||||||
" \"finish_reason\": \"length\"\n",
|
|
||||||
" }\n",
|
|
||||||
" ],\n",
|
|
||||||
" \"usage\": {\n",
|
|
||||||
" \"prompt_tokens\": 135,\n",
|
|
||||||
" \"completion_tokens\": 16,\n",
|
|
||||||
" \"total_tokens\": 151\n",
|
|
||||||
" }\n",
|
|
||||||
"}\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
@ -37,88 +17,92 @@
|
||||||
"import openai\n",
|
"import openai\n",
|
||||||
"import json\n",
|
"import json\n",
|
||||||
"\n",
|
"\n",
|
||||||
"openai.api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\" # can be anything\n",
|
"\n",
|
||||||
"openai.api_base = \"http://100.64.159.73:8000/v1\"\n",
|
"client = openai.OpenAI(\n",
|
||||||
|
" api_key = \"sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx\", # can be anything\n",
|
||||||
|
" base_url = \"http://100.64.159.73:8000/v1\"\n",
|
||||||
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Example dummy function hard coded to return the same weather\n",
|
"# Example dummy function hard coded to return the same weather\n",
|
||||||
"# In production, this could be your backend API or an external API\n",
|
"# In production, this could be your backend API or an external API\n",
|
||||||
"def get_current_weather(location, unit=\"fahrenheit\"):\n",
|
"def get_current_weather(location, unit=\"fahrenheit\"):\n",
|
||||||
" \"\"\"Get the current weather in a given location\"\"\"\n",
|
" \"\"\"Get the current weather in a given location\"\"\"\n",
|
||||||
" weather_info = {\n",
|
" if \"tokyo\" in location.lower():\n",
|
||||||
" \"location\": location,\n",
|
" return json.dumps({\"location\": \"Tokyo\", \"temperature\": \"10\", \"unit\": \"celsius\"})\n",
|
||||||
" \"temperature\": \"72\",\n",
|
" elif \"san francisco\" in location.lower():\n",
|
||||||
" \"unit\": unit,\n",
|
" return json.dumps({\"location\": \"San Francisco\", \"temperature\": \"72\", \"unit\": \"fahrenheit\"})\n",
|
||||||
" \"forecast\": [\"sunny\", \"windy\"],\n",
|
" elif \"paris\" in location.lower():\n",
|
||||||
" }\n",
|
" return json.dumps({\"location\": \"Paris\", \"temperature\": \"22\", \"unit\": \"celsius\"})\n",
|
||||||
" return json.dumps(weather_info)\n",
|
" else:\n",
|
||||||
|
" return json.dumps({\"location\": location, \"temperature\": \"unknown\"})\n",
|
||||||
"\n",
|
"\n",
|
||||||
"def run_conversation():\n",
|
"def run_conversation():\n",
|
||||||
" # Step 1: send the conversation and available functions to GPT\n",
|
" # Step 1: send the conversation and available functions to the model\n",
|
||||||
" messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}]\n",
|
" messages = [{\"role\": \"user\", \"content\": \"What's the weather like in San Francisco, Tokyo, and Paris?\"}]\n",
|
||||||
" functions = [\n",
|
" tools = [\n",
|
||||||
" {\n",
|
" {\n",
|
||||||
" \"name\": \"get_current_weather\",\n",
|
" \"type\": \"function\",\n",
|
||||||
" \"description\": \"Get the current weather in a given location\",\n",
|
" \"function\": {\n",
|
||||||
" \"parameters\": {\n",
|
" \"name\": \"get_current_weather\",\n",
|
||||||
" \"type\": \"object\",\n",
|
" \"description\": \"Get the current weather in a given location\",\n",
|
||||||
" \"properties\": {\n",
|
" \"parameters\": {\n",
|
||||||
" \"location\": {\n",
|
" \"type\": \"object\",\n",
|
||||||
" \"type\": \"string\",\n",
|
" \"properties\": {\n",
|
||||||
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
|
" \"location\": {\n",
|
||||||
|
" \"type\": \"string\",\n",
|
||||||
|
" \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
|
||||||
|
" },\n",
|
||||||
|
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n",
|
" \"required\": [\"location\"],\n",
|
||||||
" },\n",
|
" },\n",
|
||||||
" \"required\": [\"location\"],\n",
|
|
||||||
" },\n",
|
" },\n",
|
||||||
" }\n",
|
" }\n",
|
||||||
" ]\n",
|
" ]\n",
|
||||||
" response = openai.ChatCompletion.create(\n",
|
" response = client.chat.completions.create(\n",
|
||||||
" model=\"gpt-3.5-turbo-0613\",\n",
|
" model=\"gpt-3.5-turbo-1106\",\n",
|
||||||
" messages=messages,\n",
|
" messages=messages,\n",
|
||||||
" functions=functions,\n",
|
" tools=tools,\n",
|
||||||
" function_call=\"auto\", # auto is default, but we'll be explicit\n",
|
" tool_choice=\"auto\", # auto is default, but we'll be explicit\n",
|
||||||
" )\n",
|
" )\n",
|
||||||
" response_message = response[\"choices\"][0][\"message\"]\n",
|
" response_message = response.choices[0].message\n",
|
||||||
"\n",
|
" tool_calls = response_message.tool_calls\n",
|
||||||
" # Step 2: check if GPT wanted to call a function\n",
|
" # Step 2: check if the model wanted to call a function\n",
|
||||||
" if response_message.get(\"function_call\"):\n",
|
" if tool_calls:\n",
|
||||||
" # Step 3: call the function\n",
|
" # Step 3: call the function\n",
|
||||||
" # Note: the JSON response may not always be valid; be sure to handle errors\n",
|
" # Note: the JSON response may not always be valid; be sure to handle errors\n",
|
||||||
" available_functions = {\n",
|
" available_functions = {\n",
|
||||||
" \"get_current_weather\": get_current_weather,\n",
|
" \"get_current_weather\": get_current_weather,\n",
|
||||||
" } # only one function in this example, but you can have multiple\n",
|
" } # only one function in this example, but you can have multiple\n",
|
||||||
" function_name = response_message[\"function_call\"][\"name\"]\n",
|
|
||||||
" fuction_to_call = available_functions[function_name]\n",
|
|
||||||
" function_args = json.loads(response_message[\"function_call\"][\"arguments\"])\n",
|
|
||||||
" function_response = fuction_to_call(\n",
|
|
||||||
" location=function_args.get(\"location\"),\n",
|
|
||||||
" unit=function_args.get(\"unit\"),\n",
|
|
||||||
" )\n",
|
|
||||||
"\n",
|
|
||||||
" # Step 4: send the info on the function call and function response to GPT\n",
|
|
||||||
" messages.append(response_message) # extend conversation with assistant's reply\n",
|
" messages.append(response_message) # extend conversation with assistant's reply\n",
|
||||||
" messages.append(\n",
|
" # Step 4: send the info for each function call and function response to the model\n",
|
||||||
" {\n",
|
" for tool_call in tool_calls:\n",
|
||||||
" \"role\": \"function\",\n",
|
" function_name = tool_call.function.name\n",
|
||||||
" \"name\": function_name,\n",
|
" function_to_call = available_functions[function_name]\n",
|
||||||
" \"content\": function_response,\n",
|
" function_args = json.loads(tool_call.function.arguments)\n",
|
||||||
" }\n",
|
" function_response = function_to_call(\n",
|
||||||
" ) # extend conversation with function response\n",
|
" location=function_args.get(\"location\"),\n",
|
||||||
" second_response = openai.ChatCompletion.create(\n",
|
" unit=function_args.get(\"unit\"),\n",
|
||||||
" model=\"gpt-3.5-turbo-0613\",\n",
|
" )\n",
|
||||||
|
" messages.append(\n",
|
||||||
|
" {\n",
|
||||||
|
" \"tool_call_id\": tool_call.id,\n",
|
||||||
|
" \"role\": \"tool\",\n",
|
||||||
|
" \"name\": function_name,\n",
|
||||||
|
" \"content\": function_response,\n",
|
||||||
|
" }\n",
|
||||||
|
" ) # extend conversation with function response\n",
|
||||||
|
" second_response = client.chat.completions.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo-1106\",\n",
|
||||||
" messages=messages,\n",
|
" messages=messages,\n",
|
||||||
" ) # get a new response from GPT where it can see the function response\n",
|
" ) # get a new response from the model where it can see the function response\n",
|
||||||
" return second_response\n",
|
" return second_response\n",
|
||||||
" else:\n",
|
|
||||||
" print(response)\n",
|
|
||||||
" print(\"No function\")\n",
|
|
||||||
"\n",
|
|
||||||
"print(run_conversation())"
|
"print(run_conversation())"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 30,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
|
@ -130,66 +114,257 @@
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
|
"import instructor\n",
|
||||||
"from pydantic import BaseModel\n",
|
"from pydantic import BaseModel\n",
|
||||||
"from instructor import patch\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"patch()\n",
|
"# Enables `response_model`\n",
|
||||||
|
"client = instructor.patch(client=client)\n",
|
||||||
"\n",
|
"\n",
|
||||||
"class UserDetail(BaseModel):\n",
|
"class UserDetail(BaseModel):\n",
|
||||||
" name: str\n",
|
" name: str\n",
|
||||||
" age: int\n",
|
" age: int\n",
|
||||||
"\n",
|
"\n",
|
||||||
"user: UserDetail = openai.ChatCompletion.create(\n",
|
"user = client.chat.completions.create(\n",
|
||||||
" model=\"gpt-3.5-turbo\",\n",
|
" model=\"gpt-3.5-turbo\",\n",
|
||||||
" response_model=UserDetail,\n",
|
" response_model=UserDetail,\n",
|
||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n",
|
" {\"role\": \"user\", \"content\": \"Extract Jason is 25 years old\"},\n",
|
||||||
" ]\n",
|
" ]\n",
|
||||||
")\n",
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"assert isinstance(user, UserDetail)\n",
|
||||||
|
"assert user.name == \"Jason\"\n",
|
||||||
|
"assert user.age == 25\n",
|
||||||
|
"\n",
|
||||||
"print(user)"
|
"print(user)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 31,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import enum\n",
|
||||||
|
"\n",
|
||||||
|
"class Labels(str, enum.Enum):\n",
|
||||||
|
" \"\"\"Enumeration for single-label text classification.\"\"\"\n",
|
||||||
|
" SPAM = \"spam\"\n",
|
||||||
|
" NOT_SPAM = \"not_spam\"\n",
|
||||||
|
"\n",
|
||||||
|
"class SinglePrediction(BaseModel):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Class for a single class label prediction.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" class_label: Labels\n",
|
||||||
|
"\n",
|
||||||
|
"def classify(data: str) -> SinglePrediction:\n",
|
||||||
|
" \"\"\"Perform single-label classification on the input text.\"\"\"\n",
|
||||||
|
" return client.chat.completions.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo-0613\",\n",
|
||||||
|
" response_model=SinglePrediction,\n",
|
||||||
|
" messages=[\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": f\"Classify the following text: {data}\",\n",
|
||||||
|
" },\n",
|
||||||
|
" ],\n",
|
||||||
|
" ) # type: ignore\n",
|
||||||
|
"\n",
|
||||||
|
"prediction = classify(\"Hello there I'm a Nigerian prince and I want to give you money\")\n",
|
||||||
|
"assert prediction.class_label == Labels.SPAM"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 32,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"{\n",
|
"class_labels=[<MultiLabels.BILLING: 'billing'>, <MultiLabels.TECH_ISSUE: 'tech_issue'>]\n"
|
||||||
" \"id\": \"chatcmpl-59bcefad-9df5-4d6b-802c-5537b3e9044e\",\n",
|
|
||||||
" \"object\": \"chat.completion\",\n",
|
|
||||||
" \"created\": 1698989585,\n",
|
|
||||||
" \"model\": \"gpt-3.5-turbo-0613\",\n",
|
|
||||||
" \"choices\": [\n",
|
|
||||||
" {\n",
|
|
||||||
" \"index\": 0,\n",
|
|
||||||
" \"message\": {\n",
|
|
||||||
" \"role\": \"assistant\",\n",
|
|
||||||
" \"content\": \"I don't have up-to-date information on the current weather conditions\"\n",
|
|
||||||
" },\n",
|
|
||||||
" \"finish_reason\": \"length\"\n",
|
|
||||||
" }\n",
|
|
||||||
" ],\n",
|
|
||||||
" \"usage\": {\n",
|
|
||||||
" \"prompt_tokens\": 62,\n",
|
|
||||||
" \"completion_tokens\": 16,\n",
|
|
||||||
" \"total_tokens\": 78\n",
|
|
||||||
" }\n",
|
|
||||||
"}\n"
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"response = openai.ChatCompletion.create(\n",
|
"from typing import List\n",
|
||||||
" model=\"gpt-3.5-turbo-0613\",\n",
|
"\n",
|
||||||
|
"# Define Enum class for multiple labels\n",
|
||||||
|
"class MultiLabels(str, enum.Enum):\n",
|
||||||
|
" TECH_ISSUE = \"tech_issue\"\n",
|
||||||
|
" BILLING = \"billing\"\n",
|
||||||
|
" GENERAL_QUERY = \"general_query\"\n",
|
||||||
|
"\n",
|
||||||
|
"# Define the multi-class prediction model\n",
|
||||||
|
"class MultiClassPrediction(BaseModel):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Class for a multi-class label prediction.\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" class_labels: List[MultiLabels]\n",
|
||||||
|
"\n",
|
||||||
|
"def multi_classify(data: str) -> MultiClassPrediction:\n",
|
||||||
|
" \"\"\"Perform multi-label classification on the input text.\"\"\"\n",
|
||||||
|
" return client.chat.completions.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo-0613\",\n",
|
||||||
|
" response_model=MultiClassPrediction,\n",
|
||||||
|
" messages=[\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": f\"Classify the following support ticket: {data}\",\n",
|
||||||
|
" },\n",
|
||||||
|
" ],\n",
|
||||||
|
" ) # type: ignore\n",
|
||||||
|
"\n",
|
||||||
|
"# Test multi-label classification\n",
|
||||||
|
"ticket = \"My account is locked and I can't access my billing info.\"\n",
|
||||||
|
"prediction = multi_classify(ticket)\n",
|
||||||
|
"print(prediction)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 33,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"question='What is the meaning of life?' answer='The meaning of life, according to the Devil, is to live a life of sin and debauchery.'\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from typing_extensions import Annotated\n",
|
||||||
|
"from pydantic import BaseModel, BeforeValidator\n",
|
||||||
|
"\n",
|
||||||
|
"from instructor import llm_validator\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"question = \"What is the meaning of life?\"\n",
|
||||||
|
"context = \"The according to the devil the meaning of live is to live a life of sin and debauchery.\"\n",
|
||||||
|
"\n",
|
||||||
|
"class QuestionAnswer(BaseModel):\n",
|
||||||
|
" question: str\n",
|
||||||
|
" answer: str\n",
|
||||||
|
"\n",
|
||||||
|
"qa: QuestionAnswer = client.chat.completions.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo\",\n",
|
||||||
|
" response_model=QuestionAnswer,\n",
|
||||||
" messages=[\n",
|
" messages=[\n",
|
||||||
" {\"role\": \"user\", \"content\": \"What's the weather like in Boston?\"}\n",
|
" {\n",
|
||||||
" ]\n",
|
" \"role\": \"system\",\n",
|
||||||
|
" \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": f\"using the context: {context}\\n\\nAnswer the following question: {question}\",\n",
|
||||||
|
" },\n",
|
||||||
|
" ],\n",
|
||||||
")\n",
|
")\n",
|
||||||
"print(response)"
|
"print(qa)\n",
|
||||||
|
"\n",
|
||||||
|
"class QuestionAnswerNoEvil(BaseModel):\n",
|
||||||
|
" question: str\n",
|
||||||
|
" answer: Annotated[\n",
|
||||||
|
" str,\n",
|
||||||
|
" BeforeValidator(\n",
|
||||||
|
" llm_validator(\"don't say objectionable things\", allow_override=True)\n",
|
||||||
|
" ),\n",
|
||||||
|
" ]\n",
|
||||||
|
"\n",
|
||||||
|
"try:\n",
|
||||||
|
" qa: QuestionAnswerNoEvil = client.chat.completions.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo\",\n",
|
||||||
|
" response_model=QuestionAnswerNoEvil,\n",
|
||||||
|
" messages=[\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"system\",\n",
|
||||||
|
" \"content\": \"You are a system that answers questions based on the context. answer exactly what the question asks using the context.\",\n",
|
||||||
|
" },\n",
|
||||||
|
" {\n",
|
||||||
|
" \"role\": \"user\",\n",
|
||||||
|
" \"content\": f\"using the context: {context}\\n\\nAnswer the following question: {question}\",\n",
|
||||||
|
" },\n",
|
||||||
|
" ],\n",
|
||||||
|
" )\n",
|
||||||
|
"except Exception as e:\n",
|
||||||
|
" print(e)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 42,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"question='What did the author do during college?' answer=[Fact(fact='The author, Jason Liu, studied Computational Mathematics and Physics in university.', substring_quote=['Computational Mathematics'])]\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import re\n",
|
||||||
|
"from typing import List\n",
|
||||||
|
"\n",
|
||||||
|
"from pydantic import Field, BaseModel, model_validator, FieldValidationInfo\n",
|
||||||
|
"\n",
|
||||||
|
"class Fact(BaseModel):\n",
|
||||||
|
" fact: str = Field(...)\n",
|
||||||
|
" substring_quote: List[str] = Field(...)\n",
|
||||||
|
"\n",
|
||||||
|
" @model_validator(mode=\"after\")\n",
|
||||||
|
" def validate_sources(self, info: FieldValidationInfo) -> \"Fact\":\n",
|
||||||
|
" text_chunks = info.context.get(\"text_chunk\", None)\n",
|
||||||
|
" spans = list(self.get_spans(text_chunks))\n",
|
||||||
|
" self.substring_quote = [text_chunks[span[0] : span[1]] for span in spans]\n",
|
||||||
|
" return self\n",
|
||||||
|
"\n",
|
||||||
|
" def get_spans(self, context):\n",
|
||||||
|
" for quote in self.substring_quote:\n",
|
||||||
|
" yield from self._get_span(quote, context)\n",
|
||||||
|
"\n",
|
||||||
|
" def _get_span(self, quote, context):\n",
|
||||||
|
" for match in re.finditer(re.escape(quote), context):\n",
|
||||||
|
" yield match.span()\n",
|
||||||
|
"\n",
|
||||||
|
"class QuestionAnswer(BaseModel):\n",
|
||||||
|
" question: str = Field(...)\n",
|
||||||
|
" answer: List[Fact] = Field(...)\n",
|
||||||
|
"\n",
|
||||||
|
" @model_validator(mode=\"after\")\n",
|
||||||
|
" def validate_sources(self) -> \"QuestionAnswer\":\n",
|
||||||
|
" self.answer = [fact for fact in self.answer if len(fact.substring_quote) > 0]\n",
|
||||||
|
" return self\n",
|
||||||
|
"\n",
|
||||||
|
"\n",
|
||||||
|
"def ask_ai(question: str, context: str) -> QuestionAnswer:\n",
|
||||||
|
" return client.chat.completions.create(\n",
|
||||||
|
" model=\"gpt-3.5-turbo-0613\",\n",
|
||||||
|
" temperature=0.0,\n",
|
||||||
|
" response_model=QuestionAnswer,\n",
|
||||||
|
" messages=[\n",
|
||||||
|
" {\"role\": \"system\", \"content\": \"You are a world class algorithm to answer questions with correct and exact citations.\"},\n",
|
||||||
|
" {\"role\": \"user\", \"content\": f\"{context}\"},\n",
|
||||||
|
" {\"role\": \"user\", \"content\": f\"Question: {question}\"}\n",
|
||||||
|
" ],\n",
|
||||||
|
" validation_context={\"text_chunk\": context},\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
"question = \"What did the author do during college?\"\n",
|
||||||
|
"context = \"\"\"\n",
|
||||||
|
"My name is Jason Liu, and I grew up in Toronto Canada but I was born in China.\n",
|
||||||
|
"I went to an arts high school but in university I studied Computational Mathematics and physics.\n",
|
||||||
|
"As part of coop I worked at many companies including Stitchfix, Facebook.\n",
|
||||||
|
"I also started the Data Science club at the University of Waterloo and I was the president of the club for 2 years.\n",
|
||||||
|
"\"\"\"\n",
|
||||||
|
"\n",
|
||||||
|
"qa = ask_ai(question, context)\n",
|
||||||
|
"print(qa)"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
import ctypes
|
import ctypes
|
||||||
import dataclasses
|
import dataclasses
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, Protocol
|
||||||
|
@ -31,7 +32,7 @@ class LlamaChatCompletionHandler(Protocol):
|
||||||
response_format: Optional[
|
response_format: Optional[
|
||||||
llama_types.ChatCompletionRequestResponseFormat
|
llama_types.ChatCompletionRequestResponseFormat
|
||||||
] = None,
|
] = None,
|
||||||
max_tokens: int = 256,
|
max_tokens: Optional[int] = None,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
|
@ -292,7 +293,7 @@ def register_chat_format(name: str):
|
||||||
response_format: Optional[
|
response_format: Optional[
|
||||||
llama_types.ChatCompletionRequestResponseFormat
|
llama_types.ChatCompletionRequestResponseFormat
|
||||||
] = None,
|
] = None,
|
||||||
max_tokens: int = 256,
|
max_tokens: Optional[int] = None,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
|
@ -573,13 +574,15 @@ def functionary_chat_handler(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
functions: Optional[List[llama_types.ChatCompletionFunction]] = None,
|
||||||
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None,
|
||||||
|
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||||
|
tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None,
|
||||||
temperature: float = 0.2,
|
temperature: float = 0.2,
|
||||||
top_p: float = 0.95,
|
top_p: float = 0.95,
|
||||||
top_k: int = 40,
|
top_k: int = 40,
|
||||||
stream: bool = False,
|
stream: bool = False,
|
||||||
stop: Optional[Union[str, List[str]]] = [],
|
stop: Optional[Union[str, List[str]]] = [],
|
||||||
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
|
response_format: Optional[llama_types.ChatCompletionRequestResponseFormat] = None,
|
||||||
max_tokens: int = 256,
|
max_tokens: Optional[int] = None,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
|
@ -594,57 +597,80 @@ def functionary_chat_handler(
|
||||||
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
|
) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]:
|
||||||
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
|
SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary"""
|
||||||
|
|
||||||
def generate_schema_from_functions(
|
def generate_type_definition(param: Dict[str, llama_types.JsonType], indent_level: int, shared_defs) -> str:
|
||||||
functions: List[llama_types.ChatCompletionFunctions],
|
indent = ' ' * indent_level
|
||||||
namespace: str = "functions",
|
if '$ref' in param:
|
||||||
):
|
# Reference to a shared definition
|
||||||
"""
|
ref_name = param['$ref'].split('/')[-1] # Extract the type name from the reference
|
||||||
Convert functions schema to a schema that language models can understand.
|
return ref_name
|
||||||
"""
|
elif param.get('type') == 'array':
|
||||||
|
items = param.get('items', {})
|
||||||
|
item_type = generate_type_definition(items, indent_level + 1, shared_defs)
|
||||||
|
return f"Array<{item_type}>"
|
||||||
|
elif param.get('type') == 'object':
|
||||||
|
properties = param.get('properties', {})
|
||||||
|
nested_schema = "{\n"
|
||||||
|
for nested_param_name, nested_param in properties.items():
|
||||||
|
nested_param_type = generate_type_definition(nested_param, indent_level + 1, shared_defs)
|
||||||
|
nested_schema += f"{indent} {nested_param_name}: {nested_param_type},\n"
|
||||||
|
nested_schema += indent + "}"
|
||||||
|
return nested_schema
|
||||||
|
elif 'enum' in param:
|
||||||
|
# Enum type
|
||||||
|
return " | ".join([f'"{enum_value}"' for enum_value in param['enum']])
|
||||||
|
else:
|
||||||
|
# Simple type
|
||||||
|
return param.get('type', 'any')
|
||||||
|
|
||||||
schema = (
|
def generate_shared_definitions(shared_defs, indent_level: int) -> str:
|
||||||
"// Supported function definitions that should be called when necessary.\n"
|
indent = ' ' * indent_level
|
||||||
)
|
shared_definitions = ""
|
||||||
|
for def_name, def_properties in shared_defs.items():
|
||||||
|
shared_definitions += f"{indent}type {def_name} = "
|
||||||
|
if def_properties.get('type') == 'object':
|
||||||
|
shared_definitions += generate_type_definition(def_properties, indent_level, shared_defs)
|
||||||
|
elif 'enum' in def_properties:
|
||||||
|
# Enum type
|
||||||
|
shared_definitions += " | ".join([f'"{enum_value}"' for enum_value in def_properties['enum']])
|
||||||
|
shared_definitions += ";\n"
|
||||||
|
return shared_definitions
|
||||||
|
|
||||||
|
def generate_schema_from_functions(functions, namespace="functions") -> str:
|
||||||
|
schema = "// Supported function definitions that should be called when necessary.\n"
|
||||||
schema += f"namespace {namespace} {{\n\n"
|
schema += f"namespace {namespace} {{\n\n"
|
||||||
|
|
||||||
|
# Generate shared definitions
|
||||||
|
shared_definitions = {}
|
||||||
|
for function in functions:
|
||||||
|
parameters = function.get("parameters", {})
|
||||||
|
shared_definitions.update(parameters.get("$defs", {}))
|
||||||
|
|
||||||
|
schema += generate_shared_definitions(shared_definitions, 1)
|
||||||
|
|
||||||
for function in functions:
|
for function in functions:
|
||||||
# Convert a Function object to dict, if necessary
|
|
||||||
function_name = function["name"]
|
function_name = function["name"]
|
||||||
description = function.get("description", "")
|
description = function.get("description", "")
|
||||||
schema += f"// {description}\n"
|
parameters = function.get("parameters", {})
|
||||||
schema += f"type {function_name}"
|
|
||||||
|
|
||||||
parameters = function.get("parameters", None)
|
|
||||||
schema += " = (_: {\n"
|
|
||||||
required_params = parameters.get("required", [])
|
required_params = parameters.get("required", [])
|
||||||
|
|
||||||
|
schema += f" // {description}\n"
|
||||||
|
schema += f" type {function_name} = (_: {{\n"
|
||||||
|
|
||||||
for param_name, param in parameters.get("properties", {}).items():
|
for param_name, param in parameters.get("properties", {}).items():
|
||||||
# Param Description
|
param_description = param.get("description", "")
|
||||||
description = param.get("description")
|
param_type = generate_type_definition(param, 2, shared_definitions)
|
||||||
if description is not None:
|
optional_indicator = "" if param_name in required_params else "?"
|
||||||
schema += f"// {description}\n"
|
schema += f" // {param_description}\n"
|
||||||
|
schema += f" {param_name}{optional_indicator}: {param_type},\n"
|
||||||
# Param Name
|
schema += " }) => any;\n\n"
|
||||||
schema += f"{param_name}"
|
|
||||||
if param_name not in required_params:
|
|
||||||
schema += "?"
|
|
||||||
|
|
||||||
# Param Type
|
|
||||||
param_type = param.get("type", "any")
|
|
||||||
if param_type == "integer":
|
|
||||||
param_type = "number"
|
|
||||||
if "enum" in param:
|
|
||||||
param_type = " | ".join([f'"{v}"' for v in param["enum"]])
|
|
||||||
schema += f": {param_type},\n"
|
|
||||||
|
|
||||||
schema += "}) => any;\n\n"
|
|
||||||
|
|
||||||
schema += f"}} // namespace {namespace}"
|
|
||||||
|
|
||||||
|
schema += "}} // namespace {}\n".format(namespace)
|
||||||
return schema
|
return schema
|
||||||
|
|
||||||
def prepare_messages_for_inference(
|
def prepare_messages_for_inference(
|
||||||
messages: List[llama_types.ChatCompletionRequestMessage],
|
messages: List[llama_types.ChatCompletionRequestMessage],
|
||||||
functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
|
functions: Optional[List[llama_types.ChatCompletionFunctions]] = None,
|
||||||
|
tools: Optional[List[llama_types.ChatCompletionTool]] = None,
|
||||||
):
|
):
|
||||||
all_messages: List[llama_types.ChatCompletionRequestMessage] = []
|
all_messages: List[llama_types.ChatCompletionRequestMessage] = []
|
||||||
if functions is not None:
|
if functions is not None:
|
||||||
|
@ -653,6 +679,15 @@ def functionary_chat_handler(
|
||||||
role="system", content=generate_schema_from_functions(functions)
|
role="system", content=generate_schema_from_functions(functions)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if tools is not None:
|
||||||
|
all_messages.append(
|
||||||
|
llama_types.ChatCompletionRequestSystemMessage(
|
||||||
|
role="system", content=generate_schema_from_functions(
|
||||||
|
[tool["function"] for tool in tools if tool["type"] == "function"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
all_messages.append(
|
all_messages.append(
|
||||||
llama_types.ChatCompletionRequestSystemMessage(
|
llama_types.ChatCompletionRequestSystemMessage(
|
||||||
|
@ -685,16 +720,24 @@ def functionary_chat_handler(
|
||||||
return f"function name={msg['name']}:\n{msg['content']}\n"
|
return f"function name={msg['name']}:\n{msg['content']}\n"
|
||||||
elif msg["role"] == "function" and "function_call" in msg:
|
elif msg["role"] == "function" and "function_call" in msg:
|
||||||
return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
|
return f"function name={msg['function_call']['name']}:\n{msg['function_call']['arguments']}\n"
|
||||||
|
elif msg["role"] == "tool":
|
||||||
|
if msg["content"] is not None:
|
||||||
|
return f"function name={msg['tool_call_id']}:\n{msg['content']}\n"
|
||||||
|
else:
|
||||||
|
return f"function name={msg['tool_call_id']}\n"
|
||||||
elif msg["role"] == "user":
|
elif msg["role"] == "user":
|
||||||
if msg["content"] is None:
|
if msg["content"] is None:
|
||||||
return "user:\n</s>"
|
return "user:\n</s></s>\n"
|
||||||
else:
|
else:
|
||||||
return f"user:\n</s>{msg['content']}\n"
|
return f"user:\n</s>{msg['content']}</s>\n"
|
||||||
elif msg["role"] == "assistant":
|
elif msg["role"] == "assistant":
|
||||||
if msg["content"] is not None and "function_call" in msg:
|
if msg["content"] is not None and "function_call" in msg:
|
||||||
return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
|
return f"assistant:\n{msg['content']}\nassistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
|
||||||
elif "function_call" in msg:
|
elif "function_call" in msg:
|
||||||
return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>"
|
return f"assistant to={msg['function_call']['name']}:\n{msg['function_call']['arguments']}</s>\n"
|
||||||
|
elif "tool_calls" in msg and len(msg["tool_calls"]) > 0:
|
||||||
|
for tool_call in msg["tool_calls"]: # NOTE: probably doesn't work with the functionary model
|
||||||
|
return f"assistant to={tool_call['id']}:\n{tool_call['function']['arguments']}</s>\n"
|
||||||
elif msg["content"] is None:
|
elif msg["content"] is None:
|
||||||
return "assistant"
|
return "assistant"
|
||||||
else:
|
else:
|
||||||
|
@ -703,8 +746,14 @@ def functionary_chat_handler(
|
||||||
raise ValueError(f"Unsupported role: {msg['role']}")
|
raise ValueError(f"Unsupported role: {msg['role']}")
|
||||||
|
|
||||||
return "".join([message_to_str(msg) for msg in all_messages])
|
return "".join([message_to_str(msg) for msg in all_messages])
|
||||||
|
|
||||||
|
if tools is not None:
|
||||||
|
functions = [tool["function"] for tool in tools if tool["type"] == "function"]
|
||||||
|
|
||||||
|
if tool_choice is not None:
|
||||||
|
function_call = tool_choice if isinstance(tool_choice, str) else tool_choice["function"]
|
||||||
|
|
||||||
prompt = prepare_messages_for_inference(messages, functions)
|
prompt = prepare_messages_for_inference(messages, functions, tools)
|
||||||
|
|
||||||
if function_call is None and (functions is None or len(functions) == 0):
|
if function_call is None and (functions is None or len(functions) == 0):
|
||||||
completion_or_completion_chunks = llama.create_completion(
|
completion_or_completion_chunks = llama.create_completion(
|
||||||
|
@ -737,27 +786,68 @@ def functionary_chat_handler(
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
completion_text = completion["choices"][0]["text"]
|
completion_text = completion["choices"][0]["text"]
|
||||||
# strip " to=functions." and ending ":"
|
# strip " to=functions." and ending ":"
|
||||||
function_call = completion_text[14:-1]
|
function_call = completion_text.split(".")[-1][:-1]
|
||||||
new_prompt = prompt + completion_text + stop
|
new_prompt = prompt + completion_text + stop
|
||||||
elif isinstance(function_call, str) and function_call != "none":
|
elif isinstance(function_call, str) and function_call != "none":
|
||||||
new_prompt = prompt + f"assistant:\n"
|
new_prompt = prompt + f":\n"
|
||||||
elif isinstance(function_call, dict):
|
elif isinstance(function_call, dict):
|
||||||
new_prompt = prompt + f"assistant to={function_call['name']}:\n"
|
new_prompt = prompt + f" to=functions.{function_call['name']}:\n"
|
||||||
function_call = function_call["name"]
|
function_call = function_call["name"]
|
||||||
else:
|
else:
|
||||||
new_prompt = prompt + f"assistant:\n"
|
new_prompt = prompt + f":\n"
|
||||||
|
|
||||||
|
function_body = None
|
||||||
|
for function in functions or []:
|
||||||
|
if function["name"] == function_call:
|
||||||
|
function_body = function["parameters"]
|
||||||
|
break
|
||||||
|
for tool in tools or []:
|
||||||
|
if tool["type"] == "function" and tool["function"]["name"] == function_call:
|
||||||
|
function_body = tool["function"]["parameters"]
|
||||||
|
break
|
||||||
|
|
||||||
|
if function_body is not None:
|
||||||
|
try:
|
||||||
|
with suppress_stdout_stderr(disable=llama.verbose):
|
||||||
|
grammar_text = llama_grammar.json_schema_to_gbnf(json.dumps(function_body))
|
||||||
|
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.json_schema_to_gbnf(json.dumps(function_body)))
|
||||||
|
print(grammar_text)
|
||||||
|
except Exception as e:
|
||||||
|
if llama.verbose:
|
||||||
|
print("Failed to parse function body as JSON schema, falling back to default grammar")
|
||||||
|
print(e)
|
||||||
|
with suppress_stdout_stderr(disable=llama.verbose):
|
||||||
|
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||||
|
else:
|
||||||
|
with suppress_stdout_stderr(disable=llama.verbose):
|
||||||
|
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
||||||
|
|
||||||
completion: llama_types.Completion = llama.create_completion(
|
completion: llama_types.Completion = llama.create_completion(
|
||||||
prompt=new_prompt, stop=["user:", "</s>"], stream=False
|
prompt=new_prompt,
|
||||||
|
stop=["user:", "</s>"],
|
||||||
|
stream=False,
|
||||||
|
grammar=grammar,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
top_k=top_k,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
repeat_penalty=repeat_penalty,
|
||||||
|
tfs_z=tfs_z,
|
||||||
|
mirostat_mode=mirostat_mode,
|
||||||
|
mirostat_tau=mirostat_tau,
|
||||||
|
mirostat_eta=mirostat_eta,
|
||||||
|
model=model,
|
||||||
|
logits_processor=logits_processor,
|
||||||
) # type: ignore
|
) # type: ignore
|
||||||
|
|
||||||
assert "usage" in completion
|
assert "usage" in completion
|
||||||
assert isinstance(function_call, str)
|
assert isinstance(function_call, str)
|
||||||
assert stream is False # TODO: support stream mode
|
assert stream is False # TODO: support stream mode
|
||||||
|
|
||||||
if response_format is not None and response_format["type"] == "json_object":
|
print(new_prompt)
|
||||||
with suppress_stdout_stderr(disable=llama.verbose):
|
print(completion["choices"][0]["text"])
|
||||||
grammar = llama_grammar.LlamaGrammar.from_string(llama_grammar.JSON_GBNF)
|
|
||||||
|
|
||||||
return llama_types.CreateChatCompletionResponse(
|
return llama_types.CreateChatCompletionResponse(
|
||||||
id="chat" + completion["id"],
|
id="chat" + completion["id"],
|
||||||
|
@ -768,14 +858,24 @@ def functionary_chat_handler(
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": {
|
"message": {
|
||||||
"role": "function",
|
"role": "assistant",
|
||||||
"content": None,
|
"content": None,
|
||||||
"function_call": {
|
"function_call": {
|
||||||
"name": function_call,
|
"name": function_call,
|
||||||
"arguments": completion["choices"][0]["text"],
|
"arguments": completion["choices"][0]["text"],
|
||||||
},
|
},
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"id": function_call,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": function_call,
|
||||||
|
"arguments": completion["choices"][0]["text"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
},
|
},
|
||||||
"finish_reason": "function_call",
|
"finish_reason": "tool_calls",
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
usage=completion["usage"],
|
usage=completion["usage"],
|
||||||
|
@ -834,7 +934,7 @@ class Llava15ChatHandler:
|
||||||
response_format: Optional[
|
response_format: Optional[
|
||||||
llama_types.ChatCompletionRequestResponseFormat
|
llama_types.ChatCompletionRequestResponseFormat
|
||||||
] = None,
|
] = None,
|
||||||
max_tokens: int = 256,
|
max_tokens: Optional[int] = None,
|
||||||
presence_penalty: float = 0.0,
|
presence_penalty: float = 0.0,
|
||||||
frequency_penalty: float = 0.0,
|
frequency_penalty: float = 0.0,
|
||||||
repeat_penalty: float = 1.1,
|
repeat_penalty: float = 1.1,
|
||||||
|
|
Loading…
Reference in a new issue