This commit is contained in:
Andrei Betlen 2024-04-21 20:46:42 -04:00
commit b21ba0e2ac
4 changed files with 31 additions and 3 deletions

View file

@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased] ## [Unreleased]
## [0.2.63]
- feat: Update llama.cpp to ggerganov/llama.cpp@0e4802b2ecbaab04b4f829fde4a3096ca19c84b5
- feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct by @abetlen in cc81afebf04d26ca1ac3cf72f23f18da6ab58588
## [0.2.62] ## [0.2.62]
- feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c - feat: Update llama.cpp to ggerganov/llama.cpp@3b8f1ec4b18770531d0b1d792f3edf08254e4f0c

View file

@ -1,4 +1,4 @@
from .llama_cpp import * from .llama_cpp import *
from .llama import * from .llama import *
__version__ = "0.2.62" __version__ = "0.2.63"

View file

@ -426,7 +426,10 @@ class Llama:
print(f"Using chat bos_token: {bos_token}", file=sys.stderr) print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
self.chat_handler = llama_chat_format.Jinja2ChatFormatter( self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
template=template, eos_token=eos_token, bos_token=bos_token template=template,
eos_token=eos_token,
bos_token=bos_token,
stop_token_ids=[eos_token_id],
).to_chat_handler() ).to_chat_handler()
if self.chat_format is None and self.chat_handler is None: if self.chat_format is None and self.chat_handler is None:

View file

@ -10,6 +10,9 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, P
import jinja2 import jinja2
import numpy as np
import numpy.typing as npt
import llama_cpp.llama as llama import llama_cpp.llama as llama
import llama_cpp.llama_types as llama_types import llama_cpp.llama_types as llama_types
import llama_cpp.llama_grammar as llama_grammar import llama_cpp.llama_grammar as llama_grammar
@ -150,6 +153,7 @@ class ChatFormatterResponse:
prompt: str prompt: str
stop: Optional[Union[str, List[str]]] = None stop: Optional[Union[str, List[str]]] = None
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
class ChatFormatter(Protocol): class ChatFormatter(Protocol):
@ -173,12 +177,14 @@ class Jinja2ChatFormatter(ChatFormatter):
eos_token: str, eos_token: str,
bos_token: str, bos_token: str,
add_generation_prompt: bool = True, add_generation_prompt: bool = True,
stop_token_ids: Optional[List[int]] = None,
): ):
"""A chat formatter that uses jinja2 templates to format the prompt.""" """A chat formatter that uses jinja2 templates to format the prompt."""
self.template = template self.template = template
self.eos_token = eos_token self.eos_token = eos_token
self.bos_token = bos_token self.bos_token = bos_token
self.add_generation_prompt = add_generation_prompt self.add_generation_prompt = add_generation_prompt
self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
self._environment = jinja2.Environment( self._environment = jinja2.Environment(
loader=jinja2.BaseLoader(), loader=jinja2.BaseLoader(),
@ -211,7 +217,16 @@ class Jinja2ChatFormatter(ChatFormatter):
tool_choice=tool_choice, tool_choice=tool_choice,
) )
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token]) stopping_criteria = None
if self.stop_token_ids is not None:
def stop_on_last_token(
tokens: npt.NDArray[np.intc],
logits: npt.NDArray[np.single]
) -> bool:
return tokens[-1] in self.stop_token_ids
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
def to_chat_handler(self) -> LlamaChatCompletionHandler: def to_chat_handler(self) -> LlamaChatCompletionHandler:
return chat_formatter_to_chat_completion_handler(self) return chat_formatter_to_chat_completion_handler(self)
@ -533,6 +548,10 @@ def chat_formatter_to_chat_completion_handler(
rstop = result.stop if isinstance(result.stop, list) else [result.stop] rstop = result.stop if isinstance(result.stop, list) else [result.stop]
stop = stop + rstop stop = stop + rstop
stopping_criteria = None
if result.stopping_criteria is not None:
stopping_criteria = result.stopping_criteria
if response_format is not None and response_format["type"] == "json_object": if response_format is not None and response_format["type"] == "json_object":
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose) grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
@ -598,6 +617,7 @@ def chat_formatter_to_chat_completion_handler(
mirostat_eta=mirostat_eta, mirostat_eta=mirostat_eta,
model=model, model=model,
logits_processor=logits_processor, logits_processor=logits_processor,
stopping_criteria=stopping_criteria,
grammar=grammar, grammar=grammar,
logit_bias=logit_bias, logit_bias=logit_bias,
) )