feat: Support multiple chat templates - step 1 (#1396)

* Support multiple chat templates - step 1

As a first step, allow the user to select a template from the metadata with the chat_format parameter, using names of the form `chat_template.name` (see the usage sketch below).

* register chat templates to self.chat_formats instead of globally

* Don't expose internal chat handlers yet

---------

Co-authored-by: Andrei <abetlen@gmail.com>
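
Below is a minimal usage sketch of the new selection mechanism. The model file and the tool_use template are hypothetical; the sketch assumes a GGUF whose metadata carries both tokenizer.chat_template and tokenizer.chat_template.tool_use, which this patch registers as chat_template.default and chat_template.tool_use respectively.

    from llama_cpp import Llama

    # Hypothetical model file; with verbose=True the patch lists the templates
    # found in metadata, e.g.:
    #   Available chat formats from metadata: chat_template.tool_use, chat_template.default
    llm = Llama(
        model_path="model.gguf",
        chat_format="chat_template.tool_use",  # select a named template from metadata
        verbose=True,
    )

    response = llm.create_chat_completion(
        messages=[{"role": "user", "content": "What time is it in Tokyo?"}],
    )

An unrecognized name still falls through to the global format registry (see the last hunk below), so existing chat_format values such as "llama-2" behave as before.
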
commit 5ab40e6167
parent bf66a283e8
Author: Sigbjørn Skjæret
Date: 2024-05-09 15:49:09 +02:00 (committed by GitHub)

llama_cpp/llama.py

@@ -378,6 +378,7 @@ class Llama:
         self.chat_format = chat_format
         self.chat_handler = chat_handler
+        self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {}
 
         self.draft_model = draft_model
@@ -409,10 +410,33 @@
         if self.verbose:
             print(f"Model metadata: {self.metadata}", file=sys.stderr)
 
+        eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos()))
+        bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos()))
+
+        eos_token = self._model.token_get_text(eos_token_id)
+        bos_token = self._model.token_get_text(bos_token_id)
+
+        # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates
+        template_choices = dict((name[10:], template) for name, template in self.metadata.items() if name.startswith("tokenizer.chat_template."))
+
+        if "tokenizer.chat_template" in self.metadata:
+            template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"]
+
+        if self.verbose and template_choices:
+            print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr)
+
+        for name, template in template_choices.items():
+            self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter(
+                template=template,
+                eos_token=eos_token,
+                bos_token=bos_token,
+                stop_token_ids=[eos_token_id],
+            ).to_chat_handler()
+
         if (
             self.chat_format is None
             and self.chat_handler is None
-            and "tokenizer.chat_template" in self.metadata
+            and "chat_template.default" in template_choices
         ):
             chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata(
                 self.metadata
@@ -423,30 +447,12 @@
                 if self.verbose:
                     print(f"Guessed chat format: {chat_format}", file=sys.stderr)
             else:
-                template = self.metadata["tokenizer.chat_template"]
-                try:
-                    eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"])
-                except:
-                    eos_token_id = self.token_eos()
-                try:
-                    bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"])
-                except:
-                    bos_token_id = self.token_bos()
-
-                eos_token = self._model.token_get_text(eos_token_id)
-                bos_token = self._model.token_get_text(bos_token_id)
-
                 if self.verbose:
-                    print(f"Using gguf chat template: {template}", file=sys.stderr)
+                    print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr)
                     print(f"Using chat eos_token: {eos_token}", file=sys.stderr)
                     print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
 
-                self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
-                    template=template,
-                    eos_token=eos_token,
-                    bos_token=bos_token,
-                    stop_token_ids=[eos_token_id],
-                ).to_chat_handler()
+                self.chat_format = "chat_template.default"
 
         if self.chat_format is None and self.chat_handler is None:
             self.chat_format = "llama-2"
@@ -1719,7 +1725,7 @@ class Llama:
         Returns:
             Generated chat completion or a stream of chat completion chunks.
         """
-        handler = self.chat_handler or llama_chat_format.get_chat_completion_handler(
+        handler = self.chat_handler or self._chat_handlers.get(self.chat_format) or llama_chat_format.get_chat_completion_handler(
             self.chat_format
         )
         return handler(
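
As a closing note, the name mapping in the second hunk can be illustrated standalone (the metadata values here are hypothetical): name[10:] drops the 10-character "tokenizer." prefix, so every registered name keeps a "chat_template." prefix that doubles as a valid chat_format value.

    # Standalone sketch of the template discovery above (hypothetical metadata).
    metadata = {
        "tokenizer.chat_template": "{# default Jinja template #}",
        "tokenizer.chat_template.tool_use": "{# tool-use Jinja template #}",
    }

    # name[10:] strips the 10-character "tokenizer." prefix; the suffix-less
    # default key does not match the filter and is added under a fixed name.
    template_choices = {
        name[10:]: template
        for name, template in metadata.items()
        if name.startswith("tokenizer.chat_template.")
    }
    template_choices["chat_template.default"] = metadata["tokenizer.chat_template"]

    print(sorted(template_choices))  # ['chat_template.default', 'chat_template.tool_use']

Each name is then bound to a Jinja2ChatFormatter-backed handler in self._chat_handlers, which is what the handler lookup in the final hunk consults before falling back to the global registry.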