feat: Add stopping_criteria to ChatFormatter, allow stopping on arbitrary token ids, fixes llama3 instruct
This commit is contained in:
parent
d17c1887a3
commit
cc81afebf0
2 changed files with 25 additions and 2 deletions
|
@ -426,7 +426,10 @@ class Llama:
|
||||||
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
print(f"Using chat bos_token: {bos_token}", file=sys.stderr)
|
||||||
|
|
||||||
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
self.chat_handler = llama_chat_format.Jinja2ChatFormatter(
|
||||||
template=template, eos_token=eos_token, bos_token=bos_token
|
template=template,
|
||||||
|
eos_token=eos_token,
|
||||||
|
bos_token=bos_token,
|
||||||
|
stop_token_ids=[eos_token_id],
|
||||||
).to_chat_handler()
|
).to_chat_handler()
|
||||||
|
|
||||||
if self.chat_format is None and self.chat_handler is None:
|
if self.chat_format is None and self.chat_handler is None:
|
||||||
|
|
|
@ -10,6 +10,9 @@ from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, P
|
||||||
|
|
||||||
import jinja2
|
import jinja2
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
import llama_cpp.llama as llama
|
import llama_cpp.llama as llama
|
||||||
import llama_cpp.llama_types as llama_types
|
import llama_cpp.llama_types as llama_types
|
||||||
import llama_cpp.llama_grammar as llama_grammar
|
import llama_cpp.llama_grammar as llama_grammar
|
||||||
|
@ -150,6 +153,7 @@ class ChatFormatterResponse:
|
||||||
|
|
||||||
prompt: str
|
prompt: str
|
||||||
stop: Optional[Union[str, List[str]]] = None
|
stop: Optional[Union[str, List[str]]] = None
|
||||||
|
stopping_criteria: Optional[llama.StoppingCriteriaList] = None
|
||||||
|
|
||||||
|
|
||||||
class ChatFormatter(Protocol):
|
class ChatFormatter(Protocol):
|
||||||
|
@ -173,12 +177,14 @@ class Jinja2ChatFormatter(ChatFormatter):
|
||||||
eos_token: str,
|
eos_token: str,
|
||||||
bos_token: str,
|
bos_token: str,
|
||||||
add_generation_prompt: bool = True,
|
add_generation_prompt: bool = True,
|
||||||
|
stop_token_ids: Optional[List[int]] = None,
|
||||||
):
|
):
|
||||||
"""A chat formatter that uses jinja2 templates to format the prompt."""
|
"""A chat formatter that uses jinja2 templates to format the prompt."""
|
||||||
self.template = template
|
self.template = template
|
||||||
self.eos_token = eos_token
|
self.eos_token = eos_token
|
||||||
self.bos_token = bos_token
|
self.bos_token = bos_token
|
||||||
self.add_generation_prompt = add_generation_prompt
|
self.add_generation_prompt = add_generation_prompt
|
||||||
|
self.stop_token_ids = set(stop_token_ids) if stop_token_ids is not None else None
|
||||||
|
|
||||||
self._environment = jinja2.Environment(
|
self._environment = jinja2.Environment(
|
||||||
loader=jinja2.BaseLoader(),
|
loader=jinja2.BaseLoader(),
|
||||||
|
@ -211,7 +217,16 @@ class Jinja2ChatFormatter(ChatFormatter):
|
||||||
tool_choice=tool_choice,
|
tool_choice=tool_choice,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token])
|
stopping_criteria = None
|
||||||
|
if self.stop_token_ids is not None:
|
||||||
|
def stop_on_last_token(
|
||||||
|
tokens: npt.NDArray[np.intc],
|
||||||
|
logits: npt.NDArray[np.single]
|
||||||
|
) -> bool:
|
||||||
|
return tokens[-1] in self.stop_token_ids
|
||||||
|
stopping_criteria = llama.StoppingCriteriaList([stop_on_last_token])
|
||||||
|
|
||||||
|
return ChatFormatterResponse(prompt=prompt, stop=[self.eos_token], stopping_criteria=stopping_criteria)
|
||||||
|
|
||||||
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
def to_chat_handler(self) -> LlamaChatCompletionHandler:
|
||||||
return chat_formatter_to_chat_completion_handler(self)
|
return chat_formatter_to_chat_completion_handler(self)
|
||||||
|
@ -533,6 +548,10 @@ def chat_formatter_to_chat_completion_handler(
|
||||||
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
rstop = result.stop if isinstance(result.stop, list) else [result.stop]
|
||||||
stop = stop + rstop
|
stop = stop + rstop
|
||||||
|
|
||||||
|
stopping_criteria = None
|
||||||
|
if result.stopping_criteria is not None:
|
||||||
|
stopping_criteria = result.stopping_criteria
|
||||||
|
|
||||||
if response_format is not None and response_format["type"] == "json_object":
|
if response_format is not None and response_format["type"] == "json_object":
|
||||||
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
|
grammar = _grammar_for_response_format(response_format, verbose=llama.verbose)
|
||||||
|
|
||||||
|
@ -598,6 +617,7 @@ def chat_formatter_to_chat_completion_handler(
|
||||||
mirostat_eta=mirostat_eta,
|
mirostat_eta=mirostat_eta,
|
||||||
model=model,
|
model=model,
|
||||||
logits_processor=logits_processor,
|
logits_processor=logits_processor,
|
||||||
|
stopping_criteria=stopping_criteria,
|
||||||
grammar=grammar,
|
grammar=grammar,
|
||||||
logit_bias=logit_bias,
|
logit_bias=logit_bias,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Reference in a new issue