From cc81afebf04d26ca1ac3cf72f23f18da6ab58588 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Sat, 20 Apr 2024 00:00:53 -0400 Subject: [PATCH] feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct --- llama_cpp/llama.py | 5 ++++- llama_cpp/llama_chat_format.py | 22 +++++++++++++++++++++- 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 5a0111b..818be82 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -426,7 +426,10 @@ class Llama: print(f"Using chat bos_token: {bos_token}", file=sys.stderr) self.chat_handler = llama_chat_format.Jinja2ChatFormatter( - template=template, eos_token=eos_token, bos_token=bos_token + template=template, + eos_token=eos_token, + bos_token=bos_token, + stop_token_ids=[eos_token_id], ).to_chat_handler() if self.chat_format is None and self.chat_handler is None: diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index eb98cbf..189ccb0 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -10,6 +10,9 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, P import jinja2 +import numpy as np +import numpy.typing as npt + import llama_cpp.llama as llama import llama_cpp.llama_types as llama_types import llama_cpp.llama_grammar as llama_grammar @@ -150,6 +153,7 @@ class ChatFormatterResponse: prompt: str stop: Optional[Union[str, List[str]]] = None + stopping_criteria: Optional[llama.StoppingCriteriaList] = None class ChatFormatter(Protocol): @@ -173,12 +177,14 @@ class Jinja2ChatFormatter(ChatFormatter): eos_token: str, bos_token: str, add_generation_prompt: bool = True, + stop_token_ids: Optional[List[int]] = None, ): """A chat formatter that uses jinja2 templates to format the prompt.""" self.template = template self.eos_token = eos_token self.bos_token = bos_token self.add_generation_prompt = add_generation_prompt + self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None self._environment = jinja2.Environment( loader=jinja2.BaseLoader(), @@ -211,7 +217,16 @@ class Jinja2ChatFormatter(ChatFormatter): tool_choice=tool_choice, ) - return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token]) + stopping_criteria = None + if self.stop_token_ids is not None: + def stop_on_last_token( + tokens: npt.NDArray[np.intc], + logits: npt.NDArray[np.single] + ) -> bool: + return tokens[-1] in self.stop_token_ids + stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token]) + + return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria) def to_chat_handler(self) -> LlamaChatCompletionHandler: return chat_formatter_to_chat_completion_handler(self) @@ -533,6 +548,10 @@ def chat_formatter_to_chat_completion_handler( rstop = result.stop if isinstance(result.stop, list) else [result.stop] stop = stop + rstop + stopping_criteria = None + if result.stopping_criteria is not None: + stopping_criteria = result.stopping_criteria + if response_format is not None and response_format["type"] == "json_object": grammar = _grammar_for_response_format(response_format, verbose=llama.verbose) @@ -598,6 +617,7 @@ def chat_formatter_to_chat_completion_handler( mirostat_eta=mirostat_eta, model=model, logits_processor=logits_processor, + stopping_criteria=stopping_criteria, grammar=grammar, logit_bias=logit_bias, )