From 345215a76cf57b769474ea5dc1aefc5ccfb06d5c Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Tue, 13 Feb 2024 23:02:50 -0500 Subject: [PATCH] fix: more chatml-function-calling fixes --- llama_cpp/llama_chat_format.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 809a827..7f365e3 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -2061,12 +2061,12 @@ def chatml_function_calling( "\nfunctions.:" '\n{ "arg1": "value1", "arg2": "value2" }' "{% endif %}" - "\n<|im_end|>\n" + "<|im_end|>\n" "{% endif %}" # User message "{% if message.role == 'user' %}" "{{ message.content }}" - "\n<|im_end|>\n" + "<|im_end|>\n" "{% endif %}" # Assistant message "{% if message.role == 'assistant' %}" @@ -2076,7 +2076,7 @@ def chatml_function_calling( "message:\n" "{% endif %}" "{{ message.content }}" - "\n<|im_end|>\n" + "<|im_end|>\n" "{% endif %}" ## Function calls "{% if 'tool_calls' in message %}" @@ -2084,11 +2084,11 @@ def chatml_function_calling( "functions.{{ tool_call.function.name }}:\n" "{{ tool_call.function.arguments }}" "{% endfor %}" - "\n<|im_end|>\n" + "<|im_end|>\n" "{% endif %}" "{% endif %}" "{% endfor %}" - "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" + "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" ) template_renderer = jinja2.Environment( loader=jinja2.BaseLoader(), @@ -2120,6 +2120,8 @@ def chatml_function_calling( }, } + stop = [stop, "<|im_end|>"] if isinstance(stop, str) else stop + ["<|im_end|>"] if stop else ["<|im_end|>"] + # Case 1: No tool choice by user if ( tool_choice is None