diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py
index 66e40ae..809a827 100644
--- a/llama_cpp/llama_chat_format.py
+++ b/llama_cpp/llama_chat_format.py
@@ -2088,6 +2088,7 @@ def chatml_function_calling(
         "{% endif %}"
         "{% endif %}"
         "{% endfor %}"
+        "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
     )
     template_renderer = jinja2.Environment(
         loader=jinja2.BaseLoader(),
@@ -2130,6 +2131,7 @@ def chatml_function_calling(
         messages=messages,
         tools=[],
         tool_calls=None,
+        add_generation_prompt=True,
     )
     if response_format is not None and response_format["type"] == "json_object":
         try:
@@ -2363,6 +2365,7 @@ def chatml_function_calling(
             messages=messages,
             tools=tools,
             tool_calls=True,
+            add_generation_prompt=True,
         )
         prompt += f"functions.{tool_name}:\n"
         try:
@@ -2420,6 +2423,7 @@ def chatml_function_calling(
             messages=messages,
             tools=tools,
             tool_calls=True,
+            add_generation_prompt=True,
         )
         completion_or_chunks = llama.create_completion(
             prompt=prompt,
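
For context: the change threads `add_generation_prompt` through the ChatML function-calling template so every rendered prompt ends with an open assistant turn rather than stopping after the last message. Below is a minimal sketch of the effect, using a simplified template rather than the library's full function-calling one (the message shape and the `CHATML_SKETCH` name are illustrative):

```python
# A minimal sketch (not the library's full template) of what the
# `add_generation_prompt` guard changes in the rendered ChatML prompt.
import jinja2

# Simplified ChatML template with the same guard the diff adds.
CHATML_SKETCH = (
    "{% for message in messages %}"
    "<|im_start|>{{ message.role }}\n{{ message.content }}<|im_end|>\n"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

template = jinja2.Environment(loader=jinja2.BaseLoader()).from_string(CHATML_SKETCH)

prompt = template.render(
    messages=[{"role": "user", "content": "Hello"}],
    add_generation_prompt=True,  # ends the prompt with an open assistant turn
)
print(prompt)
# <|im_start|>user
# Hello<|im_end|>
# <|im_start|>assistant
```

Without the trailing `<|im_start|>assistant\n`, the model is completing after an `<|im_end|>` token and may emit another full turn header (or a stray user turn) instead of starting its reply, which is why the flag is passed as `True` at each of the render call sites above.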