feat: Support multiple chat templates - step 1 (#1396)
* Support multiple chat templates - step 1

  As a first step, allow the user to select a template from the model metadata via the chat_format parameter, using names of the form `chat_template.name`.

* Register chat templates to self._chat_handlers instead of globally

* Don't expose internal chat handlers yet

---------

Co-authored-by: Andrei <abetlen@gmail.com>
This commit is contained in:
parent bf66a283e8
commit 5ab40e6167

1 changed file with 28 additions and 22 deletions
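For context, a minimal usage sketch of what this step enables, assuming a local GGUF model (the model path below is hypothetical; `chat_template.default` is the name the constructor registers for the GGUF's `tokenizer.chat_template` key, as the diff below shows):

```python
from llama_cpp import Llama

# Hypothetical model path; any GGUF model whose metadata carries a
# tokenizer.chat_template key will do. After this commit, chat_format
# can name a handler registered from that metadata.
llm = Llama(
    model_path="./model.gguf",
    chat_format="chat_template.default",
)

response = llm.create_chat_completion(
    messages=[{"role": "user", "content": "Hello!"}]
)
```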
llama_cpp/llama.py:

@@ -378,6 +378,7 @@ class Llama:
         self.chat_format = chat_format
         self.chat_handler = chat_handler
+        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {}

         self.draft_model = draft_model

@@ -409,10 +410,33 @@ class Llama:
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)

+        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
+        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+
+        eos_token = self._model.token_get_text(eos_token_id)
+        bos_token = self._model.token_get_text(bos_token_id)
+
+        # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
+        template_choices = dict((name[10:], template) for name, template in self.metadata.items() if name.startswith("tokenizer.chat_template."))
+
+        if "tokenizer.chat_template" in self.metadata:
+            template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"]
+
+        if self.verbose and template_choices:
+            print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr)
+
+        for name, template in template_choices.items():
+            self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
+                template=template,
+                eos_token=eos_token,
+                bos_token=bos_token,
+                stop_token_ids=[eos_token_id],
+            ).to_chat_handler()
+
         if (
             self.chat_format is None
             and self.chat_handler is None
-            and "tokenizer.chat_template" in self.metadata
+            and "chat_template.default" in template_choices
         ):
             chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
                 self.metadata
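A note on the `name[10:]` slice in the comprehension above: `len("tokenizer.") == 10`, so stripping those ten characters turns a metadata key like `tokenizer.chat_template.tool_use` into the selectable name `chat_template.tool_use`. An illustrative sketch with made-up metadata values:

```python
# Made-up metadata for illustration; real values come from the GGUF file.
metadata = {
    "tokenizer.chat_template": "{% for m in messages %}...{% endfor %}",
    "tokenizer.chat_template.tool_use": "{# tool-use variant #}",
}

# Mirrors the comprehension in the diff: name[10:] drops "tokenizer.".
template_choices = dict(
    (name[10:], template)
    for name, template in metadata.items()
    if name.startswith("tokenizer.chat_template.")
)
# {'chat_template.tool_use': '{# tool-use variant #}'}

# The bare key is then registered under the reserved default name.
if "tokenizer.chat_template" in metadata:
    template_choices["chat_template.default"] = metadata["tokenizer.chat_template"]
```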
@@ -423,30 +447,12 @@ class Llama:
             if self.verbose:
                 print(f"Guessed chat format: {chat_format}", file=sys.stderr)
         else:
-            template = self.metadata["tokenizer.chat_template"]
-            try:
-                eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
-            except:
-                eos_token_id = self.token_eos()
-            try:
-                bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
-            except:
-                bos_token_id = self.token_bos()
-
-            eos_token = self._model.token_get_text(eos_token_id)
-            bos_token = self._model.token_get_text(bos_token_id)
-
             if self.verbose:
-                print(f"Using gguf chat template: {template}", file=sys.stderr)
+                print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr)
                 print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                 print(f"Using chat bos_token: {bos_token}", file=sys.stderr)

-            self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                template=template,
-                eos_token=eos_token,
-                bos_token=bos_token,
-                stop_token_ids=[eos_token_id],
-            ).to_chat_handler()
+            self.chat_format = "chat_template.default"

         if self.chat_format is None and self.chat_handler is None:
             self.chat_format = "llama-2"
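Net effect of this hunk: the else branch no longer re-parses the template and special tokens itself; that work now happens once up front, and using the GGUF template is just a matter of pointing self.chat_format at the already-registered chat_template.default handler.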
@@ -1719,7 +1725,7 @@ class Llama:
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
+        handler = self.chat_handler or self._chat_handlers.get(self.chat_format) or llama_chat_format.get_chat_completion_handler(
             self.chat_format
         )
         return handler(
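Dispatch order after this change: an explicitly passed chat_handler wins, then a handler registered from GGUF metadata under self._chat_handlers whose name matches self.chat_format (e.g. "chat_template.default"), and only then the globally registered formats via llama_chat_format.get_chat_completion_handler (e.g. "llama-2").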