Add speculative decoding (#1120)
* Add draft model param to llama class, implement basic prompt lookup decoding draft model * Use samplingcontext for sampling * Use 1d array * Use draft model for sampling * Fix dumb mistake * Allow for later extensions to the LlamaDraftModel api * Cleanup * Adaptive candidate prediction * Update implementation to match hf transformers * Tuning * Fix bug where last token was not used for ngram prediction * Remove heuristic for num_pred_tokens (no benefit) * fix: n_candidates bug. * Add draft_model_num_pred_tokens server setting * Cleanup * Update README
This commit is contained in:
parent
71e3e4c435
commit
fb762a6041
6 changed files with 206 additions and 89 deletions
18
README.md
18
README.md
|
@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process
|
|||
)
|
||||
```
|
||||
|
||||
### Speculative Decoding
|
||||
|
||||
`llama-cpp-python` supports speculative decoding which allows the model to generate completions based on a draft model.
|
||||
|
||||
The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class.
|
||||
|
||||
Just pass this as a draft model to the `Llama` class during initialization.
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
|
||||
|
||||
llama = Llama(
|
||||
model_path="path/to/model.gguf",
|
||||
draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10) # num_pred_tokens is the number of tokens to predict 10 is the default and generally good for gpu, 2 performs better for cpu-only machines.
|
||||
)
|
||||
```
|
||||
|
||||
### Adjusting the Context Window
|
||||
|
||||
The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.
|
||||
|
|
|
@ -30,6 +30,8 @@ from .llama_cache import (
|
|||
import llama_cpp.llama_cpp as llama_cpp
|
||||
import llama_cpp.llama_chat_format as llama_chat_format
|
||||
|
||||
from llama_cpp.llama_speculative import LlamaDraftModel
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
@ -39,6 +41,8 @@ from ._internals import (
|
|||
_LlamaContext, # type: ignore
|
||||
_LlamaBatch, # type: ignore
|
||||
_LlamaTokenDataArray, # type: ignore
|
||||
_LlamaSamplingParams, # type: ignore
|
||||
_LlamaSamplingContext, # type: ignore
|
||||
)
|
||||
|
||||
|
||||
|
@ -89,6 +93,8 @@ class Llama:
|
|||
# Chat Format Params
|
||||
chat_format: Optional[str] = None,
|
||||
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
|
||||
# Speculative Decoding
|
||||
draft_model: Optional[LlamaDraftModel] = None,
|
||||
# Misc
|
||||
verbose: bool = True,
|
||||
# Extra Params
|
||||
|
@ -152,6 +158,7 @@ class Llama:
|
|||
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
|
||||
chat_format: String specifying the chat format to use when calling create_chat_completion.
|
||||
chat_handler: Optional chat handler to use when calling create_chat_completion.
|
||||
draft_model: Optional draft model to use for speculative decoding.
|
||||
verbose: Print verbose output to stderr.
|
||||
|
||||
Raises:
|
||||
|
@ -315,6 +322,8 @@ class Llama:
|
|||
self.chat_format = chat_format
|
||||
self.chat_handler = chat_handler
|
||||
|
||||
self.draft_model = draft_model
|
||||
|
||||
self._n_vocab = self.n_vocab()
|
||||
self._n_ctx = self.n_ctx()
|
||||
|
||||
|
@ -503,6 +512,7 @@ class Llama:
|
|||
penalize_nl: bool = True,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
grammar: Optional[LlamaGrammar] = None,
|
||||
idx: Optional[int] = None,
|
||||
):
|
||||
"""Sample a token from the model.
|
||||
|
||||
|
@ -517,77 +527,46 @@ class Llama:
|
|||
"""
|
||||
assert self._ctx is not None
|
||||
assert self.n_tokens > 0
|
||||
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
|
||||
0, self.last_n_tokens_size - self.n_tokens
|
||||
) + self._input_ids[-self.last_n_tokens_size :].tolist()
|
||||
last_n_tokens_size = len(last_n_tokens_data)
|
||||
n_vocab = self._n_vocab
|
||||
n_ctx = self._n_ctx
|
||||
top_k = n_vocab if top_k <= 0 else top_k
|
||||
last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
|
||||
last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
|
||||
*last_n_tokens_data
|
||||
)
|
||||
logits: npt.NDArray[np.single] = self._scores[-1, :]
|
||||
|
||||
if idx is None:
|
||||
logits: npt.NDArray[np.single] = self._scores[-1, :]
|
||||
else:
|
||||
logits = self._scores[idx, :]
|
||||
|
||||
if logits_processor is not None:
|
||||
logits[:] = logits_processor(self._input_ids, logits)
|
||||
logits[:] = (
|
||||
logits_processor(self._input_ids, logits)
|
||||
if idx is None
|
||||
else logits_processor(self._input_ids[:idx], logits)
|
||||
)
|
||||
|
||||
nl_logit = logits[self._token_nl]
|
||||
self._candidates.copy_logits(logits)
|
||||
self._ctx.sample_repetition_penalties(
|
||||
candidates=self._candidates,
|
||||
last_tokens_data=last_n_tokens_data_c,
|
||||
penalty_last_n=last_n_tokens_size,
|
||||
sampling_params = _LlamaSamplingParams(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
tfs_z=tfs_z,
|
||||
typical_p=typical_p,
|
||||
temp=temp,
|
||||
penalty_last_n=self.last_n_tokens_size,
|
||||
penalty_repeat=repeat_penalty,
|
||||
penalty_freq=frequency_penalty,
|
||||
penalty_present=presence_penalty,
|
||||
mirostat=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
penalize_nl=penalize_nl,
|
||||
)
|
||||
sampling_context = _LlamaSamplingContext(
|
||||
params=sampling_params,
|
||||
grammar=grammar,
|
||||
)
|
||||
sampling_context.prev = list(self.eval_tokens)
|
||||
id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits)
|
||||
sampling_context.accept(
|
||||
ctx_main=self._ctx,
|
||||
id=id,
|
||||
apply_grammar=grammar is not None,
|
||||
)
|
||||
if not penalize_nl:
|
||||
self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
|
||||
nl_logit
|
||||
)
|
||||
|
||||
if grammar is not None:
|
||||
self._ctx.sample_grammar(
|
||||
candidates=self._candidates,
|
||||
grammar=grammar,
|
||||
)
|
||||
|
||||
if temp < 0.0:
|
||||
self._ctx.sample_softmax(candidates=self._candidates)
|
||||
id = self._candidates.candidates.data[0].id
|
||||
elif temp == 0.0:
|
||||
id = self._ctx.sample_token_greedy(candidates=self._candidates)
|
||||
elif mirostat_mode == 1:
|
||||
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
|
||||
id = self._ctx.sample_token_mirostat(
|
||||
candidates=self._candidates,
|
||||
tau=mirostat_tau,
|
||||
eta=mirostat_eta,
|
||||
mu=ctypes.pointer(self._mirostat_mu),
|
||||
m=100,
|
||||
)
|
||||
elif mirostat_mode == 2:
|
||||
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
|
||||
id = self._ctx.sample_token_mirostat_v2(
|
||||
candidates=self._candidates,
|
||||
tau=mirostat_tau,
|
||||
eta=mirostat_eta,
|
||||
mu=ctypes.pointer(self._mirostat_mu),
|
||||
)
|
||||
else:
|
||||
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
|
||||
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
|
||||
self._ctx.sample_typical(
|
||||
candidates=self._candidates, p=typical_p, min_keep=1
|
||||
)
|
||||
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
|
||||
self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
|
||||
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
|
||||
id = self._ctx.sample_token(candidates=self._candidates)
|
||||
if grammar is not None:
|
||||
self._ctx.grammar_accept_token(grammar=grammar, token=id)
|
||||
return id
|
||||
|
||||
def generate(
|
||||
|
@ -656,34 +635,56 @@ class Llama:
|
|||
if grammar is not None:
|
||||
grammar.reset()
|
||||
|
||||
sample_idx = self.n_tokens + len(tokens) - 1
|
||||
tokens = list(tokens)
|
||||
|
||||
# Eval and sample
|
||||
while True:
|
||||
self.eval(tokens)
|
||||
token = self.sample(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
temp=temp,
|
||||
repeat_penalty=repeat_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
penalize_nl=penalize_nl,
|
||||
)
|
||||
if stopping_criteria is not None and stopping_criteria(
|
||||
self._input_ids, self._scores[-1, :]
|
||||
):
|
||||
return
|
||||
tokens_or_none = yield token
|
||||
tokens = [token]
|
||||
if tokens_or_none is not None:
|
||||
tokens.extend(tokens_or_none)
|
||||
while sample_idx < self.n_tokens:
|
||||
token = self.sample(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
temp=temp,
|
||||
repeat_penalty=repeat_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
penalize_nl=penalize_nl,
|
||||
idx=sample_idx,
|
||||
)
|
||||
|
||||
sample_idx += 1
|
||||
if stopping_criteria is not None and stopping_criteria(
|
||||
self._input_ids, self._scores[-1, :]
|
||||
):
|
||||
return
|
||||
tokens_or_none = yield token
|
||||
tokens.clear()
|
||||
tokens.append(token)
|
||||
if tokens_or_none is not None:
|
||||
tokens.extend(tokens_or_none)
|
||||
|
||||
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
|
||||
self.n_tokens = sample_idx
|
||||
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
||||
break
|
||||
|
||||
if self.draft_model is not None:
|
||||
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
|
||||
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
|
||||
tokens.extend(
|
||||
draft_tokens.astype(int)[
|
||||
: self._n_ctx - self.n_tokens - len(tokens)
|
||||
]
|
||||
)
|
||||
|
||||
def create_embedding(
|
||||
self, input: Union[str, List[str]], model: Optional[str] = None
|
||||
|
|
64
llama_cpp/llama_speculative.py
Normal file
64
llama_cpp/llama_speculative.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import abc
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
class LlamaDraftModel(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def __call__(
|
||||
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
|
||||
) -> npt.NDArray[np.intc]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LlamaPromptLookupDecoding(LlamaDraftModel):
|
||||
"""Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
|
||||
|
||||
def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
|
||||
self.max_ngram_size = max_ngram_size
|
||||
self.num_pred_tokens = num_pred_tokens
|
||||
|
||||
@staticmethod
|
||||
def find_candidate_pred_tokens(
|
||||
input_ids: npt.NDArray[np.intc],
|
||||
max_ngram_size: int,
|
||||
num_pred_tokens: int,
|
||||
):
|
||||
input_length = input_ids.shape[0]
|
||||
|
||||
for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
|
||||
# Create sliding windows of size ngram_size
|
||||
windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
|
||||
|
||||
# Convert ngram to an array for comparison
|
||||
ngram_array = input_ids[-ngram_size:]
|
||||
|
||||
# Find where the windows match the ngram
|
||||
matches = np.all(windows == ngram_array, axis=1)
|
||||
|
||||
# Get the indices of matches
|
||||
match_indices = np.nonzero(matches)[0]
|
||||
|
||||
# Iterate through match indices to find a valid continuation
|
||||
for idx in match_indices:
|
||||
start_idx = idx + ngram_size
|
||||
end_idx = start_idx + num_pred_tokens
|
||||
end_idx = min(end_idx, input_length)
|
||||
|
||||
if start_idx < end_idx:
|
||||
return input_ids[start_idx:end_idx]
|
||||
|
||||
# If no match is found, return an empty array
|
||||
return np.array([], dtype=np.intc)
|
||||
|
||||
def __call__(
|
||||
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
|
||||
) -> npt.NDArray[np.intc]:
|
||||
return self.find_candidate_pred_tokens(
|
||||
input_ids=input_ids,
|
||||
max_ngram_size=self.max_ngram_size,
|
||||
num_pred_tokens=self.num_pred_tokens,
|
||||
)
|
|
@ -5,6 +5,7 @@ import json
|
|||
from typing import Dict, Optional, Union, List
|
||||
|
||||
import llama_cpp
|
||||
import llama_cpp.llama_speculative as llama_speculative
|
||||
|
||||
from llama_cpp.server.settings import ModelSettings
|
||||
|
||||
|
@ -92,6 +93,12 @@ class LlamaProxy:
|
|||
)
|
||||
)
|
||||
|
||||
draft_model = None
|
||||
if settings.draft_model is not None:
|
||||
draft_model = llama_speculative.LlamaPromptLookupDecoding(
|
||||
num_pred_tokens=settings.draft_model_num_pred_tokens
|
||||
)
|
||||
|
||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
||||
if settings.kv_overrides is not None:
|
||||
assert isinstance(settings.kv_overrides, list)
|
||||
|
@ -147,6 +154,8 @@ class LlamaProxy:
|
|||
# Chat Format Params
|
||||
chat_format=settings.chat_format,
|
||||
chat_handler=chat_handler,
|
||||
# Speculative Decoding
|
||||
draft_model=draft_model,
|
||||
# Misc
|
||||
verbose=settings.verbose,
|
||||
)
|
||||
|
|
|
@ -143,6 +143,15 @@ class ModelSettings(BaseSettings):
|
|||
default=None,
|
||||
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
|
||||
)
|
||||
# Speculative Decoding
|
||||
draft_model: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
|
||||
)
|
||||
draft_model_num_pred_tokens: int = Field(
|
||||
default=10,
|
||||
description="Number of tokens to predict using the draft model.",
|
||||
)
|
||||
# Misc
|
||||
verbose: bool = Field(
|
||||
default=True, description="Whether to print debug information."
|
||||
|
|
16
tests/test_llama_speculative.py
Normal file
16
tests/test_llama_speculative.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import numpy as np
|
||||
|
||||
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
|
||||
|
||||
def test_find_candidate_pred_tokens():
|
||||
find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens
|
||||
|
||||
# Test Case 1: Matching ngram is found
|
||||
input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
|
||||
result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
|
||||
assert np.array_equal(result1, np.array([1, 2]))
|
||||
|
||||
# Test Case 2: Matching ngram is not found
|
||||
input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
|
||||
assert np.array_equal(result2, np.array([]))
|
Loading…
Reference in a new issue