From 5ab40e6167e64ecb06c2a4d8ae798391bf9a5893 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sigbj=C3=B8rn=20Skj=C3=A6ret?= Date: Thu, 9 May 2024 15:49:09 +0200 Subject: [PATCH] feat: Support multiple chat templates - step 1 (#1396) * Support multiple chat templates - step 1 As a first step, allow user to to select template from metadata with chat_format parameter in the form of `chat_template.name`. * register chat templates to self.chat_formats instead of globally * Don't expose internal chat handlers yet --------- Co-authored-by: Andrei --- llama_cpp/llama.py | 50 ++++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5acc112..4212669 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -378,6 +378,7 @@ class Llama: self.chat_format = chat_format self.chat_handler = chat_handler + self._chat_handlers: Dict[str, llama_chat_format.LlamaChatCompletionHandler] = {} self.draft_model = draft_model @@ -409,10 +410,33 @@ class Llama: if self.verbose: print(f"Model metadata: {self.metadata}", file=sys.stderr) + eos_token_id = int(self.metadata.get("tokenizer.ggml.eos_token_id", self.token_eos())) + bos_token_id = int(self.metadata.get("tokenizer.ggml.bos_token_id", self.token_bos())) + + eos_token = self._model.token_get_text(eos_token_id) + bos_token = self._model.token_get_text(bos_token_id) + + # Unfortunately the llama.cpp API does not return metadata arrays, so we can't get template names from tokenizer.chat_templates + template_choices = dict((name[10:], template) for name, template in self.metadata.items() if name.startswith("tokenizer.chat_template.")) + + if "tokenizer.chat_template" in self.metadata: + template_choices["chat_template.default"] = self.metadata["tokenizer.chat_template"] + + if self.verbose and template_choices: + print(f"Available chat formats from metadata: {', '.join(template_choices.keys())}", file=sys.stderr) + + for name, template in template_choices.items(): + self._chat_handlers[name] = llama_chat_format.Jinja2ChatFormatter( + template=template, + eos_token=eos_token, + bos_token=bos_token, + stop_token_ids=[eos_token_id], + ).to_chat_handler() + if ( self.chat_format is None and self.chat_handler is None - and "tokenizer.chat_template" in self.metadata + and "chat_template.default" in template_choices ): chat_format = llama_chat_format.guess_chat_format_from_gguf_metadata( self.metadata @@ -423,30 +447,12 @@ class Llama: if self.verbose: print(f"Guessed chat format: {chat_format}", file=sys.stderr) else: - template = self.metadata["tokenizer.chat_template"] - try: - eos_token_id = int(self.metadata["tokenizer.ggml.eos_token_id"]) - except: - eos_token_id = self.token_eos() - try: - bos_token_id = int(self.metadata["tokenizer.ggml.bos_token_id"]) - except: - bos_token_id = self.token_bos() - - eos_token = self._model.token_get_text(eos_token_id) - bos_token = self._model.token_get_text(bos_token_id) - if self.verbose: - print(f"Using gguf chat template: {template}", file=sys.stderr) + print(f"Using gguf chat template: {template_choices['chat_template.default']}", file=sys.stderr) print(f"Using chat eos_token: {eos_token}", file=sys.stderr) print(f"Using chat bos_token: {bos_token}", file=sys.stderr) - self.chat_handler = llama_chat_format.Jinja2ChatFormatter( - template=template, - eos_token=eos_token, - bos_token=bos_token, - stop_token_ids=[eos_token_id], - ).to_chat_handler() + self.chat_format = "chat_template.default" if self.chat_format is None and self.chat_handler is None: self.chat_format = "llama-2" @@ -1719,7 +1725,7 @@ class Llama: Returns: Generated chat completion or a stream of chat completion chunks. """ - handler = self.chat_handler or llama_chat_format.get_chat_completion_handler( + handler = self.chat_handler or self._chat_handlers.get(self.chat_format) or llama_chat_format.get_chat_completion_handler( self.chat_format ) return handler(