Add speculative decoding (#1120)

* Add draft model param to llama class, implement basic prompt lookup decoding draft model

* Use samplingcontext for sampling

* Use 1d array

* Use draft model for sampling

* Fix dumb mistake

* Allow for later extensions to the LlamaDraftModel api

* Cleanup

* Adaptive candidate prediction

* Update implementation to match hf transformers

* Tuning

* Fix bug where last token was not used for ngram prediction

* Remove heuristic for num_pred_tokens (no benefit)

* fix: n_candidates bug.

* Add draft_model_num_pred_tokens server setting

* Cleanup

* Update README
This commit is contained in:
Andrei 2024-01-31 14:08:14 -05:00 committed by GitHub
parent 71e3e4c435
commit fb762a6041
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 206 additions and 89 deletions

View file

@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process
) )
``` ```
### Speculative Decoding
`llama-cpp-python` supports speculative decoding which allows the model to generate completions based on a draft model.
The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class.
Just pass this as a draft model to the `Llama` class during initialization.
```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
llama = Llama(
model_path="path/to/model.gguf",
draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10) # num_pred_tokens is the number of tokens to predict 10 is the default and generally good for gpu, 2 performs better for cpu-only machines.
)
```
### Adjusting the Context Window ### Adjusting the Context Window
The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements. The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.

View file

@ -30,6 +30,8 @@ from .llama_cache import (
import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format import llama_cpp.llama_chat_format as llama_chat_format
from llama_cpp.llama_speculative import LlamaDraftModel
import numpy as np import numpy as np
import numpy.typing as npt import numpy.typing as npt
@ -39,6 +41,8 @@ from ._internals import (
_LlamaContext, # type: ignore _LlamaContext, # type: ignore
_LlamaBatch, # type: ignore _LlamaBatch, # type: ignore
_LlamaTokenDataArray, # type: ignore _LlamaTokenDataArray, # type: ignore
_LlamaSamplingParams, # type: ignore
_LlamaSamplingContext, # type: ignore
) )
@ -89,6 +93,8 @@ class Llama:
# Chat Format Params # Chat Format Params
chat_format: Optional[str] = None, chat_format: Optional[str] = None,
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None, chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
# Speculative Decoding
draft_model: Optional[LlamaDraftModel] = None,
# Misc # Misc
verbose: bool = True, verbose: bool = True,
# Extra Params # Extra Params
@ -152,6 +158,7 @@ class Llama:
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init) numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
chat_format: String specifying the chat format to use when calling create_chat_completion. chat_format: String specifying the chat format to use when calling create_chat_completion.
chat_handler: Optional chat handler to use when calling create_chat_completion. chat_handler: Optional chat handler to use when calling create_chat_completion.
draft_model: Optional draft model to use for speculative decoding.
verbose: Print verbose output to stderr. verbose: Print verbose output to stderr.
Raises: Raises:
@ -315,6 +322,8 @@ class Llama:
self.chat_format = chat_format self.chat_format = chat_format
self.chat_handler = chat_handler self.chat_handler = chat_handler
self.draft_model = draft_model
self._n_vocab = self.n_vocab() self._n_vocab = self.n_vocab()
self._n_ctx = self.n_ctx() self._n_ctx = self.n_ctx()
@ -503,6 +512,7 @@ class Llama:
penalize_nl: bool = True, penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None, logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None, grammar: Optional[LlamaGrammar] = None,
idx: Optional[int] = None,
): ):
"""Sample a token from the model. """Sample a token from the model.
@ -517,77 +527,46 @@ class Llama:
""" """
assert self._ctx is not None assert self._ctx is not None
assert self.n_tokens > 0 assert self.n_tokens > 0
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - self.n_tokens if idx is None:
) + self._input_ids[-self.last_n_tokens_size :].tolist()
last_n_tokens_size = len(last_n_tokens_data)
n_vocab = self._n_vocab
n_ctx = self._n_ctx
top_k = n_vocab if top_k <= 0 else top_k
last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
*last_n_tokens_data
)
logits: npt.NDArray[np.single] = self._scores[-1, :] logits: npt.NDArray[np.single] = self._scores[-1, :]
else:
logits = self._scores[idx, :]
if logits_processor is not None: if logits_processor is not None:
logits[:] = logits_processor(self._input_ids, logits) logits[:] = (
logits_processor(self._input_ids, logits)
if idx is None
else logits_processor(self._input_ids[:idx], logits)
)
nl_logit = logits[self._token_nl] sampling_params = _LlamaSamplingParams(
self._candidates.copy_logits(logits) top_k=top_k,
self._ctx.sample_repetition_penalties( top_p=top_p,
candidates=self._candidates, min_p=min_p,
last_tokens_data=last_n_tokens_data_c, tfs_z=tfs_z,
penalty_last_n=last_n_tokens_size, typical_p=typical_p,
temp=temp,
penalty_last_n=self.last_n_tokens_size,
penalty_repeat=repeat_penalty, penalty_repeat=repeat_penalty,
penalty_freq=frequency_penalty, penalty_freq=frequency_penalty,
penalty_present=presence_penalty, penalty_present=presence_penalty,
mirostat=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
penalize_nl=penalize_nl,
) )
if not penalize_nl: sampling_context = _LlamaSamplingContext(
self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float( params=sampling_params,
nl_logit
)
if grammar is not None:
self._ctx.sample_grammar(
candidates=self._candidates,
grammar=grammar, grammar=grammar,
) )
sampling_context.prev = list(self.eval_tokens)
if temp < 0.0: id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits)
self._ctx.sample_softmax(candidates=self._candidates) sampling_context.accept(
id = self._candidates.candidates.data[0].id ctx_main=self._ctx,
elif temp == 0.0: id=id,
id = self._ctx.sample_token_greedy(candidates=self._candidates) apply_grammar=grammar is not None,
elif mirostat_mode == 1:
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token_mirostat(
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
mu=ctypes.pointer(self._mirostat_mu),
m=100,
) )
elif mirostat_mode == 2:
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token_mirostat_v2(
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
mu=ctypes.pointer(self._mirostat_mu),
)
else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
self._ctx.sample_typical(
candidates=self._candidates, p=typical_p, min_keep=1
)
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token(candidates=self._candidates)
if grammar is not None:
self._ctx.grammar_accept_token(grammar=grammar, token=id)
return id return id
def generate( def generate(
@ -656,9 +635,13 @@ class Llama:
if grammar is not None: if grammar is not None:
grammar.reset() grammar.reset()
sample_idx = self.n_tokens + len(tokens) - 1
tokens = list(tokens)
# Eval and sample # Eval and sample
while True: while True:
self.eval(tokens) self.eval(tokens)
while sample_idx < self.n_tokens:
token = self.sample( token = self.sample(
top_k=top_k, top_k=top_k,
top_p=top_p, top_p=top_p,
@ -675,16 +658,34 @@ class Llama:
logits_processor=logits_processor, logits_processor=logits_processor,
grammar=grammar, grammar=grammar,
penalize_nl=penalize_nl, penalize_nl=penalize_nl,
idx=sample_idx,
) )
sample_idx += 1
if stopping_criteria is not None and stopping_criteria( if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :] self._input_ids, self._scores[-1, :]
): ):
return return
tokens_or_none = yield token tokens_or_none = yield token
tokens = [token] tokens.clear()
tokens.append(token)
if tokens_or_none is not None: if tokens_or_none is not None:
tokens.extend(tokens_or_none) tokens.extend(tokens_or_none)
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
self.n_tokens = sample_idx
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
break
if self.draft_model is not None:
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
tokens.extend(
draft_tokens.astype(int)[
: self._n_ctx - self.n_tokens - len(tokens)
]
)
def create_embedding( def create_embedding(
self, input: Union[str, List[str]], model: Optional[str] = None self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse: ) -> CreateEmbeddingResponse:

View file

@ -0,0 +1,64 @@
import abc
from typing import Any
import numpy as np
import numpy.typing as npt
class LlamaDraftModel(abc.ABC):
@abc.abstractmethod
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
raise NotImplementedError()
class LlamaPromptLookupDecoding(LlamaDraftModel):
"""Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
self.max_ngram_size = max_ngram_size
self.num_pred_tokens = num_pred_tokens
@staticmethod
def find_candidate_pred_tokens(
input_ids: npt.NDArray[np.intc],
max_ngram_size: int,
num_pred_tokens: int,
):
input_length = input_ids.shape[0]
for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
# Create sliding windows of size ngram_size
windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
# Convert ngram to an array for comparison
ngram_array = input_ids[-ngram_size:]
# Find where the windows match the ngram
matches = np.all(windows == ngram_array, axis=1)
# Get the indices of matches
match_indices = np.nonzero(matches)[0]
# Iterate through match indices to find a valid continuation
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + num_pred_tokens
end_idx = min(end_idx, input_length)
if start_idx < end_idx:
return input_ids[start_idx:end_idx]
# If no match is found, return an empty array
return np.array([], dtype=np.intc)
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
return self.find_candidate_pred_tokens(
input_ids=input_ids,
max_ngram_size=self.max_ngram_size,
num_pred_tokens=self.num_pred_tokens,
)

View file

@ -5,6 +5,7 @@ import json
from typing import Dict, Optional, Union, List from typing import Dict, Optional, Union, List
import llama_cpp import llama_cpp
import llama_cpp.llama_speculative as llama_speculative
from llama_cpp.server.settings import ModelSettings from llama_cpp.server.settings import ModelSettings
@ -92,6 +93,12 @@ class LlamaProxy:
) )
) )
draft_model = None
if settings.draft_model is not None:
draft_model = llama_speculative.LlamaPromptLookupDecoding(
num_pred_tokens=settings.draft_model_num_pred_tokens
)
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
if settings.kv_overrides is not None: if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list) assert isinstance(settings.kv_overrides, list)
@ -147,6 +154,8 @@ class LlamaProxy:
# Chat Format Params # Chat Format Params
chat_format=settings.chat_format, chat_format=settings.chat_format,
chat_handler=chat_handler, chat_handler=chat_handler,
# Speculative Decoding
draft_model=draft_model,
# Misc # Misc
verbose=settings.verbose, verbose=settings.verbose,
) )

View file

@ -143,6 +143,15 @@ class ModelSettings(BaseSettings):
default=None, default=None,
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().", description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
) )
# Speculative Decoding
draft_model: Optional[str] = Field(
default=None,
description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
)
draft_model_num_pred_tokens: int = Field(
default=10,
description="Number of tokens to predict using the draft model.",
)
# Misc # Misc
verbose: bool = Field( verbose: bool = Field(
default=True, description="Whether to print debug information." default=True, description="Whether to print debug information."

View file

@ -0,0 +1,16 @@
import numpy as np
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
def test_find_candidate_pred_tokens():
find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens
# Test Case 1: Matching ngram is found
input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
assert np.array_equal(result1, np.array([1, 2]))
# Test Case 2: Matching ngram is not found
input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
assert np.array_equal(result2, np.array([]))