fix: more chatml-function-calling fixes

This commit is contained in:
Andrei Betlen 2024-02-13 23:02:50 -05:00
parent b1637c2319
commit 345215a76c

View file

@ -2061,12 +2061,12 @@ def chatml_function_calling(
"\nfunctions.<function_name>:" "\nfunctions.<function_name>:"
'\n{ "arg1": "value1", "arg2": "value2" }' '\n{ "arg1": "value1", "arg2": "value2" }'
"{% endif %}" "{% endif %}"
"\n<|im_end|>\n" "<|im_end|>\n"
"{% endif %}" "{% endif %}"
# User message # User message
"{% if message.role == 'user' %}" "{% if message.role == 'user' %}"
"{{ message.content }}" "{{ message.content }}"
"\n<|im_end|>\n" "<|im_end|>\n"
"{% endif %}" "{% endif %}"
# Assistant message # Assistant message
"{% if message.role == 'assistant' %}" "{% if message.role == 'assistant' %}"
@ -2076,7 +2076,7 @@ def chatml_function_calling(
"message:\n" "message:\n"
"{% endif %}" "{% endif %}"
"{{ message.content }}" "{{ message.content }}"
"\n<|im_end|>\n" "<|im_end|>\n"
"{% endif %}" "{% endif %}"
## Function calls ## Function calls
"{% if 'tool_calls' in message %}" "{% if 'tool_calls' in message %}"
@ -2084,11 +2084,11 @@ def chatml_function_calling(
"functions.{{ tool_call.function.name }}:\n" "functions.{{ tool_call.function.name }}:\n"
"{{ tool_call.function.arguments }}" "{{ tool_call.function.arguments }}"
"{% endfor %}" "{% endfor %}"
"\n<|im_end|>\n" "<|im_end|>\n"
"{% endif %}" "{% endif %}"
"{% endif %}" "{% endif %}"
"{% endfor %}" "{% endfor %}"
"{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}" "{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}"
) )
template_renderer = jinja2.Environment( template_renderer = jinja2.Environment(
loader=jinja2.BaseLoader(), loader=jinja2.BaseLoader(),
@ -2120,6 +2120,8 @@ def chatml_function_calling(
}, },
} }
stop = [stop, "<|im_end|>"] if isinstance(stop, str) else stop + ["<|im_end|>"] if stop else ["<|im_end|>"]
# Case 1: No tool choice by user # Case 1: No tool choice by user
if ( if (
tool_choice is None tool_choice is None