diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index c89cce8..5bda163 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -188,6 +188,10 @@ class Jinja2ChatFormatter(ChatFormatter): self, *, messages: List[llama_types.ChatCompletionRequestMessage], + functions: Optional[List[llama_types.ChatCompletionFunction]] = None, + function_call: Optional[llama_types.ChatCompletionRequestFunctionCall] = None, + tools: Optional[List[llama_types.ChatCompletionTool]] = None, + tool_choice: Optional[llama_types.ChatCompletionToolChoiceOption] = None, **kwargs: Any, ) -> ChatFormatterResponse: def raise_exception(message: str): @@ -199,6 +203,10 @@ class Jinja2ChatFormatter(ChatFormatter): bos_token=self.bos_token, raise_exception=raise_exception, add_generation_prompt=self.add_generation_prompt, + functions=functions, + function_call=function_call, + tools=tools, + tool_choice=tool_choice, ) return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token]) @@ -288,6 +296,183 @@ def _convert_completion_to_chat( return _convert_text_completion_to_chat(completion) +def _convert_completion_to_chat_function( + tool_name: str, + completion_or_chunks: Union[ + llama_types.CreateCompletionResponse, + Iterator[llama_types.CreateCompletionStreamResponse], + ], + stream: bool, +): + if not stream: + completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore + assert "usage" in completion + tool_id = "call_" + "_0_" + tool_name + "_" + completion["id"] + # TODO: Fix for legacy function calls + chat_completion: llama_types.CreateChatCompletionResponse = { + "id": "chat" + completion["id"], + "object": "chat.completion", + "created": completion["created"], + "model": completion["model"], + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": None, + "function_call": { + "name": tool_name, + "arguments": completion["choices"][0]["text"], + }, + "tool_calls": [ + { + "id": tool_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": completion["choices"][0]["text"], + }, + } + ], + }, + "finish_reason": "tool_calls", + } + ], + "usage": completion["usage"], + } + return chat_completion + else: + chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks # type: ignore + + def _stream_response_to_function_stream( + chunks: Iterator[llama_types.CreateCompletionStreamResponse], + ) -> Iterator[llama_types.CreateChatCompletionStreamResponse]: + # blank first message + first = True + id_ = None + created = None + model = None + tool_id = None + for chunk in chunks: + if first: + id_ = "chat" + chunk["id"] + created = chunk["created"] + model = chunk["model"] + tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"] + yield { + "id": id_, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "finish_reason": None, + "logprobs": None, + "delta": { + "role": "assistant", + "content": None, + "function_call": None, + "tool_calls": None, + }, + } + ], + } + yield { + "id": "chat" + chunk["id"], + "object": "chat.completion.chunk", + "created": chunk["created"], + "model": chunk["model"], + "choices": [ + { + "index": 0, + "finish_reason": None, + "logprobs": None, + "delta": { + "role": None, + "content": None, + "function_call": { + "name": tool_name, + "arguments": chunk["choices"][0]["text"], + }, + "tool_calls": [ + { + "index": 0, + "id": tool_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": "", + }, + } + ], + }, + } + ], + } + first = False + continue + assert tool_id is not None + yield { + "id": "chat" + chunk["id"], + "object": "chat.completion.chunk", + "created": chunk["created"], + "model": chunk["model"], + "choices": [ + { + "index": 0, + "finish_reason": None, + "logprobs": None, + "delta": { + "role": None, + "content": None, + "function_call": { + "name": tool_name, + "arguments": chunk["choices"][0]["text"], + }, + "tool_calls": [ + { + "index": 0, + "id": tool_id, + "type": "function", + "function": { + "name": tool_name, + "arguments": chunk["choices"][0][ + "text" + ], + }, + } + ], + }, + } + ], + } + + if id_ is not None and created is not None and model is not None: + yield { + "id": id_, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": [ + { + "index": 0, + "finish_reason": "tool_calls", + "logprobs": None, + "delta": { + "role": None, + "content": None, + "function_call": None, + "tool_calls": None, + }, + } + ], + } + + return _stream_response_to_function_stream(chunks) + + + def chat_formatter_to_chat_completion_handler( chat_formatter: ChatFormatter, ) -> LlamaChatCompletionHandler: @@ -331,6 +516,8 @@ def chat_formatter_to_chat_completion_handler( messages=messages, functions=functions, function_call=function_call, + tools=tools, + tool_choice=tool_choice, ) prompt = result.prompt if result.stop is not None: @@ -341,6 +528,47 @@ def chat_formatter_to_chat_completion_handler( if response_format is not None and response_format["type"] == "json_object": grammar = _grammar_for_response_format(response_format, verbose=llama.verbose) + # Convert legacy functions to tools + if functions is not None: + tools = [ + { + "type": "function", + "function": function, + } + for function in functions + ] + + # Convert legacy function_call to tool_choice + if function_call is not None: + if isinstance(function_call, str) and ( + function_call == "none" or function_call == "auto" + ): + tool_choice = function_call + if isinstance(function_call, dict) and "name" in function_call: + tool_choice = { + "type": "function", + "function": { + "name": function_call["name"], + }, + } + + tool = None + if tool_choice is not None and isinstance(tool_choice, dict) and tools is not None: + name = tool_choice["function"]["name"] + tool = next((t for t in tools if t["function"]["name"] == name), None) + if tool is None: + raise ValueError(f"Tool choice '{name}' not found in tools.") + schema = tool["function"]["parameters"] + try: + # create grammar from json schema + grammar = llama_grammar.LlamaGrammar.from_json_schema( + json.dumps(schema), verbose=llama.verbose + ) + except Exception as e: + grammar = llama_grammar.LlamaGrammar.from_string( + llama_grammar.JSON_GBNF, verbose=llama.verbose + ) + completion_or_chunks = llama.create_completion( prompt=prompt, temperature=temperature, @@ -364,6 +592,11 @@ def chat_formatter_to_chat_completion_handler( grammar=grammar, logit_bias=logit_bias, ) + if tool is not None: + tool_name = tool["function"]["name"] + return _convert_completion_to_chat_function( + tool_name, completion_or_chunks, stream + ) return _convert_completion_to_chat(completion_or_chunks, stream=stream) return chat_completion_handler @@ -2198,181 +2431,6 @@ def chatml_function_calling( stream=stream, ) - def _convert_completion_to_chat_function( - tool_name: str, - completion_or_chunks: Union[ - llama_types.CreateCompletionResponse, - Iterator[llama_types.CreateCompletionStreamResponse], - ], - stream: bool, - ): - if not stream: - completion: llama_types.CreateCompletionResponse = completion_or_chunks # type: ignore - assert "usage" in completion - tool_id = "call_" + "_0_" + tool_name + "_" + completion["id"] - # TODO: Fix for legacy function calls - chat_completion: llama_types.CreateChatCompletionResponse = { - "id": "chat" + completion["id"], - "object": "chat.completion", - "created": completion["created"], - "model": completion["model"], - "choices": [ - { - "index": 0, - "message": { - "role": "assistant", - "content": None, - "function_call": { - "name": tool_name, - "arguments": completion["choices"][0]["text"], - }, - "tool_calls": [ - { - "id": tool_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": completion["choices"][0]["text"], - }, - } - ], - }, - "finish_reason": "tool_calls", - } - ], - "usage": completion["usage"], - } - return chat_completion - else: - chunks: Iterator[llama_types.CreateCompletionStreamResponse] = completion_or_chunks # type: ignore - - def _stream_response_to_function_stream( - chunks: Iterator[llama_types.CreateCompletionStreamResponse], - ) -> Iterator[llama_types.CreateChatCompletionStreamResponse]: - # blank first message - first = True - id_ = None - created = None - model = None - tool_id = None - for chunk in chunks: - if first: - id_ = "chat" + chunk["id"] - created = chunk["created"] - model = chunk["model"] - tool_id = "call_" + "_0_" + tool_name + "_" + chunk["id"] - yield { - "id": id_, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "finish_reason": None, - "logprobs": None, - "delta": { - "role": "assistant", - "content": None, - "function_call": None, - "tool_calls": None, - }, - } - ], - } - yield { - "id": "chat" + chunk["id"], - "object": "chat.completion.chunk", - "created": chunk["created"], - "model": chunk["model"], - "choices": [ - { - "index": 0, - "finish_reason": None, - "logprobs": None, - "delta": { - "role": None, - "content": None, - "function_call": { - "name": tool_name, - "arguments": chunk["choices"][0]["text"], - }, - "tool_calls": [ - { - "index": 0, - "id": tool_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": "", - }, - } - ], - }, - } - ], - } - first = False - continue - assert tool_id is not None - yield { - "id": "chat" + chunk["id"], - "object": "chat.completion.chunk", - "created": chunk["created"], - "model": chunk["model"], - "choices": [ - { - "index": 0, - "finish_reason": None, - "logprobs": None, - "delta": { - "role": None, - "content": None, - "function_call": { - "name": tool_name, - "arguments": chunk["choices"][0]["text"], - }, - "tool_calls": [ - { - "index": 0, - "id": tool_id, - "type": "function", - "function": { - "name": tool_name, - "arguments": chunk["choices"][0][ - "text" - ], - }, - } - ], - }, - } - ], - } - - if id_ is not None and created is not None and model is not None: - yield { - "id": id_, - "object": "chat.completion.chunk", - "created": created, - "model": model, - "choices": [ - { - "index": 0, - "finish_reason": "tool_calls", - "logprobs": None, - "delta": { - "role": None, - "content": None, - "function_call": None, - "tool_calls": None, - }, - } - ], - } - - return _stream_response_to_function_stream(chunks) - # Case 2: Tool choice by user if isinstance(tool_choice, dict): tool_name = tool_choice["function"]["name"]