diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 06cf9ce..705202e 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -6,7 +6,7 @@ import ctypes
 import dataclasses
 import random
 import string
 
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast
 
 import jinja2
@@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
                         }
                     ],
                 },
+                "logprobs": None,
                 "finish_reason": "tool_calls",
             }
         ],
@@ -1191,7 +1192,6 @@ def format_mistral_instruct(
         elif (
             message["role"] == "assistant"
             and message["content"] is not None
-            and isinstance(message["content"], str)
         ):
             prompt += " [/INST]" + message["content"] + eos
     prompt += " [/INST]"
@@ -1263,7 +1263,7 @@ def format_gemma(
     **kwargs: Any,
 ) -> ChatFormatterResponse:
     system_message = _get_system_message(messages)
-    if system_message is not None and system_message != "":
+    if system_message != "":
         logger.debug(
             "`role='system'` messages are not allowed on Google's Gemma models."
         )
@@ -1628,6 +1628,7 @@ def functionary_chat_handler(
                         }
                     ],
                 },
+                "logprobs": None,
                 "finish_reason": "tool_calls",
             }
         ],
@@ -1909,14 +1910,14 @@ def functionary_v1_v2_chat_handler(
         return grammar
 
     def create_completion(stop):
-        completion: llama_types.Completion = llama.create_completion(
+        completion = cast(llama_types.Completion, llama.create_completion(
             prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            stream=stream,
+            stream=False,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
@@ -1929,7 +1930,7 @@ def functionary_v1_v2_chat_handler(
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
-        )
+        ))
 
         return completion
 
@@ -2050,7 +2051,7 @@ def functionary_v1_v2_chat_handler(
         assert "usage" in completion
         assert len(function_calls) == len(function_bodies)
 
-        tool_calls = []
+        tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
        for function_call, function_body in zip(function_calls, function_bodies):
            tool_calls.append(
                {
@@ -2070,6 +2071,12 @@ def functionary_v1_v2_chat_handler(
            )
 
        # TODO: support stream mode
+        function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
+            "function_call": {
+                "name": tool_calls[0]["function"]["name"],
+                "arguments": tool_calls[0]["function"]["arguments"],
+            }
+        } if len(tool_calls) == 1 else {}
        return llama_types.CreateChatCompletionResponse(
            id="chat" + completion["id"],
            object="chat.completion",
@@ -2078,14 +2085,12 @@ def functionary_v1_v2_chat_handler(
            choices=[
                {
                    "index": 0,
+                    "logprobs": None,
                    "message": {
                        "role": "assistant",
                        "content": None if content == "" else content,
-                        "function_call": {
-                            "name": tool_calls[0]["function"]["name"],
-                            "arguments": tool_calls[0]["function"]["arguments"],
-                        } if len(tool_calls) > 0 else None,
-                        "tool_calls": tool_calls if len(tool_calls) > 0 else None,
+                        "tool_calls": tool_calls,
+                        **function_call_dict,
                    },
                    "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
                }
            ],
@@ -2565,8 +2570,8 @@ def chatml_function_calling(
        tool_name = text[len("functions.") :]
        tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
        if not stream:
-            completions = []
-            completions_tool_name = []
+            completions: List[llama_types.CreateCompletionResponse] = []
+            completions_tool_name: List[str] = []
            while tool is not None:
                prompt += f"functions.{tool_name}:\n"
                try:
@@ -2603,6 +2608,7 @@ def chatml_function_calling(
                    logits_processor=logits_processor,
                    grammar=grammar,
                )
+                completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
                completions.append(completion_or_chunks)
                completions_tool_name.append(tool_name)
                prompt += completion_or_chunks["choices"][0]["text"]
@@ -2631,6 +2637,7 @@ def chatml_function_calling(
                    follow_up_gbnf_tool_grammar, verbose=llama.verbose
                ),
            )
+            response = cast(llama_types.CreateCompletionResponse, response)
            tool_name = response["choices"][0]["text"][len("functions.") :]
            tool = next(
@@ -2638,7 +2645,7 @@ def chatml_function_calling(
            )
 
            # Merge completions
-            function_call = {
+            function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
                "function_call": {
                    "name": tool_name,
                    "arguments": completions[0]["choices"][0]["text"],
@@ -2653,6 +2660,7 @@ def chatml_function_calling(
                    {
                        "finish_reason": "tool_calls",
                        "index": 0,
+                        "logprobs": None,
                        "message": {
                            "role": "assistant",
                            "content": None,
@@ -2673,20 +2681,22 @@ def chatml_function_calling(
                                    zip(completions_tool_name, completions)
                                )
                            ],
-                            **function_call
+                            **function_call_dict
                        },
                    }
                ],
                "usage": {
                    "completion_tokens": sum(
-                        completion["usage"]["completion_tokens"]
+                        completion["usage"]["completion_tokens"] if "usage" in completion else 0
                        for completion in completions
                    ),
                    "prompt_tokens": sum(
-                        completion["usage"]["prompt_tokens"] for completion in completions
+                        completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+                        for completion in completions
                    ),
                    "total_tokens": sum(
-                        completion["usage"]["total_tokens"] for completion in completions
+                        completion["usage"]["total_tokens"] if "usage" in completion else 0
+                        for completion in completions
                    ),
                },
            }