Compare commits
11 commits
cd66f3cfb4
...
b342398804
Author | SHA1 | Date | |
---|---|---|---|
b342398804 | |||
|
310fbf4e49 | ||
|
59760c85ed | ||
|
3553b14670 | ||
|
7467f129e5 | ||
|
bebfba0f08 | ||
|
8a5911bd5d | ||
|
de526d0214 | ||
|
3322eadbf3 | ||
|
a8cb34eacd | ||
|
fb762a6041 |
12 changed files with 277 additions and 127 deletions
|
@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
## [Unreleased]
|
||||
|
||||
## [0.2.38]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@1cfb5372cf5707c8ec6dde7c874f4a44a6c4c915
|
||||
- feat: Add speculative decoding by @abetlen in #1120
|
||||
- fix: Pass raise_exception and add_generation_prompt to jinja2 chat template 078cca0361bf5a94d2cf52ed04980d20e32d6f95
|
||||
|
||||
## [0.2.37]
|
||||
|
||||
- feat: Update llama.cpp to ggerganov/llama.cpp@fea4fd4ba7f6b754ac795387b275e1a014a77bde
|
||||
|
|
18
README.md
18
README.md
|
@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process
|
|||
)
|
||||
```
|
||||
|
||||
### Speculative Decoding
|
||||
|
||||
`llama-cpp-python` supports speculative decoding which allows the model to generate completions based on a draft model.
|
||||
|
||||
The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class.
|
||||
|
||||
Just pass this as a draft model to the `Llama` class during initialization.
|
||||
|
||||
```python
|
||||
from llama_cpp import Llama
|
||||
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
|
||||
|
||||
llama = Llama(
|
||||
model_path="path/to/model.gguf",
|
||||
draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10) # num_pred_tokens is the number of tokens to predict 10 is the default and generally good for gpu, 2 performs better for cpu-only machines.
|
||||
)
|
||||
```
|
||||
|
||||
### Adjusting the Context Window
|
||||
|
||||
The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from .llama_cpp import *
|
||||
from .llama import *
|
||||
|
||||
__version__ = "0.2.37"
|
||||
__version__ = "0.2.38"
|
|
@ -18,8 +18,6 @@ from .llama_grammar import LlamaGrammar
|
|||
|
||||
import llama_cpp.llama_cpp as llama_cpp
|
||||
|
||||
from ._utils import suppress_stdout_stderr
|
||||
|
||||
|
||||
# Python wrappers over llama.h structs
|
||||
|
||||
|
@ -30,7 +28,6 @@ class _LlamaModel:
|
|||
|
||||
_llama_free_model = None
|
||||
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
||||
_suppress_stdout_stderr = suppress_stdout_stderr
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -48,16 +45,14 @@ class _LlamaModel:
|
|||
if not os.path.exists(path_model):
|
||||
raise ValueError(f"Model path does not exist: {path_model}")
|
||||
|
||||
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||
self.model = llama_cpp.llama_load_model_from_file(
|
||||
self.path_model.encode("utf-8"), self.params
|
||||
)
|
||||
self.model = llama_cpp.llama_load_model_from_file(
|
||||
self.path_model.encode("utf-8"), self.params
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||
if self.model is not None and self._llama_free_model is not None:
|
||||
self._llama_free_model(self.model)
|
||||
self.model = None
|
||||
if self.model is not None and self._llama_free_model is not None:
|
||||
self._llama_free_model(self.model)
|
||||
self.model = None
|
||||
|
||||
def vocab_type(self) -> int:
|
||||
assert self.model is not None
|
||||
|
@ -240,8 +235,6 @@ class _LlamaContext:
|
|||
NOTE: For stability it's recommended you use the Llama class instead."""
|
||||
|
||||
_llama_free = None
|
||||
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
||||
_suppress_stdout_stderr = suppress_stdout_stderr
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -256,16 +249,16 @@ class _LlamaContext:
|
|||
|
||||
self._llama_free = llama_cpp._lib.llama_free # type: ignore
|
||||
|
||||
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||
self.ctx = llama_cpp.llama_new_context_with_model(
|
||||
self.model.model, self.params
|
||||
)
|
||||
assert self.model.model is not None
|
||||
|
||||
self.ctx = llama_cpp.llama_new_context_with_model(
|
||||
self.model.model, self.params
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||
if self.ctx is not None and self._llama_free is not None:
|
||||
self._llama_free(self.ctx)
|
||||
self.ctx = None
|
||||
if self.ctx is not None and self._llama_free is not None:
|
||||
self._llama_free(self.ctx)
|
||||
self.ctx = None
|
||||
|
||||
def n_ctx(self) -> int:
|
||||
assert self.ctx is not None
|
||||
|
@ -493,8 +486,6 @@ class _LlamaContext:
|
|||
|
||||
class _LlamaBatch:
|
||||
_llama_batch_free = None
|
||||
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
|
||||
_suppress_stdout_stderr = suppress_stdout_stderr
|
||||
|
||||
def __init__(
|
||||
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
|
||||
|
@ -506,16 +497,14 @@ class _LlamaBatch:
|
|||
|
||||
self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore
|
||||
|
||||
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||
self.batch = llama_cpp.llama_batch_init(
|
||||
self.n_tokens, self.embd, self.n_seq_max
|
||||
)
|
||||
self.batch = llama_cpp.llama_batch_init(
|
||||
self.n_tokens, self.embd, self.n_seq_max
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
with self._suppress_stdout_stderr(disable=self.verbose):
|
||||
if self.batch is not None and self._llama_batch_free is not None:
|
||||
self._llama_batch_free(self.batch)
|
||||
self.batch = None
|
||||
if self.batch is not None and self._llama_batch_free is not None:
|
||||
self._llama_batch_free(self.batch)
|
||||
self.batch = None
|
||||
|
||||
def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
|
||||
assert self.batch is not None
|
||||
|
|
37
llama_cpp/_logger.py
Normal file
37
llama_cpp/_logger.py
Normal file
|
@ -0,0 +1,37 @@
|
|||
import sys
|
||||
import ctypes
|
||||
import logging
|
||||
|
||||
import llama_cpp
|
||||
|
||||
# enum ggml_log_level {
|
||||
# GGML_LOG_LEVEL_ERROR = 2,
|
||||
# GGML_LOG_LEVEL_WARN = 3,
|
||||
# GGML_LOG_LEVEL_INFO = 4,
|
||||
# GGML_LOG_LEVEL_DEBUG = 5
|
||||
# };
|
||||
GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
|
||||
2: logging.ERROR,
|
||||
3: logging.WARNING,
|
||||
4: logging.INFO,
|
||||
5: logging.DEBUG,
|
||||
}
|
||||
|
||||
logger = logging.getLogger("llama-cpp-python")
|
||||
|
||||
|
||||
@llama_cpp.llama_log_callback
|
||||
def llama_log_callback(
|
||||
level: int,
|
||||
text: bytes,
|
||||
user_data: ctypes.c_void_p,
|
||||
):
|
||||
if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]:
|
||||
print(text.decode("utf-8"), end="", flush=True, file=sys.stderr)
|
||||
|
||||
|
||||
llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))
|
||||
|
||||
|
||||
def set_verbose(verbose: bool):
|
||||
logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
|
|
@ -30,16 +30,20 @@ from .llama_cache import (
|
|||
import llama_cpp.llama_cpp as llama_cpp
|
||||
import llama_cpp.llama_chat_format as llama_chat_format
|
||||
|
||||
from llama_cpp.llama_speculative import LlamaDraftModel
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
from ._utils import suppress_stdout_stderr
|
||||
from ._internals import (
|
||||
_LlamaModel, # type: ignore
|
||||
_LlamaContext, # type: ignore
|
||||
_LlamaBatch, # type: ignore
|
||||
_LlamaTokenDataArray, # type: ignore
|
||||
_LlamaSamplingParams, # type: ignore
|
||||
_LlamaSamplingContext, # type: ignore
|
||||
)
|
||||
from ._logger import set_verbose
|
||||
|
||||
|
||||
class Llama:
|
||||
|
@ -89,6 +93,8 @@ class Llama:
|
|||
# Chat Format Params
|
||||
chat_format: Optional[str] = None,
|
||||
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
|
||||
# Speculative Decoding
|
||||
draft_model: Optional[LlamaDraftModel] = None,
|
||||
# Misc
|
||||
verbose: bool = True,
|
||||
# Extra Params
|
||||
|
@ -152,6 +158,7 @@ class Llama:
|
|||
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
|
||||
chat_format: String specifying the chat format to use when calling create_chat_completion.
|
||||
chat_handler: Optional chat handler to use when calling create_chat_completion.
|
||||
draft_model: Optional draft model to use for speculative decoding.
|
||||
verbose: Print verbose output to stderr.
|
||||
|
||||
Raises:
|
||||
|
@ -162,10 +169,11 @@ class Llama:
|
|||
"""
|
||||
self.verbose = verbose
|
||||
|
||||
set_verbose(verbose)
|
||||
|
||||
self.numa = numa
|
||||
if not Llama.__backend_initialized:
|
||||
with suppress_stdout_stderr(disable=self.verbose):
|
||||
llama_cpp.llama_backend_init(self.numa)
|
||||
llama_cpp.llama_backend_init(self.numa)
|
||||
Llama.__backend_initialized = True
|
||||
|
||||
self.model_path = model_path
|
||||
|
@ -315,6 +323,8 @@ class Llama:
|
|||
self.chat_format = chat_format
|
||||
self.chat_handler = chat_handler
|
||||
|
||||
self.draft_model = draft_model
|
||||
|
||||
self._n_vocab = self.n_vocab()
|
||||
self._n_ctx = self.n_ctx()
|
||||
|
||||
|
@ -503,6 +513,7 @@ class Llama:
|
|||
penalize_nl: bool = True,
|
||||
logits_processor: Optional[LogitsProcessorList] = None,
|
||||
grammar: Optional[LlamaGrammar] = None,
|
||||
idx: Optional[int] = None,
|
||||
):
|
||||
"""Sample a token from the model.
|
||||
|
||||
|
@ -517,77 +528,46 @@ class Llama:
|
|||
"""
|
||||
assert self._ctx is not None
|
||||
assert self.n_tokens > 0
|
||||
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
|
||||
0, self.last_n_tokens_size - self.n_tokens
|
||||
) + self._input_ids[-self.last_n_tokens_size :].tolist()
|
||||
last_n_tokens_size = len(last_n_tokens_data)
|
||||
n_vocab = self._n_vocab
|
||||
n_ctx = self._n_ctx
|
||||
top_k = n_vocab if top_k <= 0 else top_k
|
||||
last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
|
||||
last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
|
||||
*last_n_tokens_data
|
||||
)
|
||||
logits: npt.NDArray[np.single] = self._scores[-1, :]
|
||||
|
||||
if idx is None:
|
||||
logits: npt.NDArray[np.single] = self._scores[-1, :]
|
||||
else:
|
||||
logits = self._scores[idx, :]
|
||||
|
||||
if logits_processor is not None:
|
||||
logits[:] = logits_processor(self._input_ids, logits)
|
||||
logits[:] = (
|
||||
logits_processor(self._input_ids, logits)
|
||||
if idx is None
|
||||
else logits_processor(self._input_ids[:idx], logits)
|
||||
)
|
||||
|
||||
nl_logit = logits[self._token_nl]
|
||||
self._candidates.copy_logits(logits)
|
||||
self._ctx.sample_repetition_penalties(
|
||||
candidates=self._candidates,
|
||||
last_tokens_data=last_n_tokens_data_c,
|
||||
penalty_last_n=last_n_tokens_size,
|
||||
sampling_params = _LlamaSamplingParams(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
tfs_z=tfs_z,
|
||||
typical_p=typical_p,
|
||||
temp=temp,
|
||||
penalty_last_n=self.last_n_tokens_size,
|
||||
penalty_repeat=repeat_penalty,
|
||||
penalty_freq=frequency_penalty,
|
||||
penalty_present=presence_penalty,
|
||||
mirostat=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
penalize_nl=penalize_nl,
|
||||
)
|
||||
sampling_context = _LlamaSamplingContext(
|
||||
params=sampling_params,
|
||||
grammar=grammar,
|
||||
)
|
||||
sampling_context.prev = list(self.eval_tokens)
|
||||
id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits)
|
||||
sampling_context.accept(
|
||||
ctx_main=self._ctx,
|
||||
id=id,
|
||||
apply_grammar=grammar is not None,
|
||||
)
|
||||
if not penalize_nl:
|
||||
self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
|
||||
nl_logit
|
||||
)
|
||||
|
||||
if grammar is not None:
|
||||
self._ctx.sample_grammar(
|
||||
candidates=self._candidates,
|
||||
grammar=grammar,
|
||||
)
|
||||
|
||||
if temp < 0.0:
|
||||
self._ctx.sample_softmax(candidates=self._candidates)
|
||||
id = self._candidates.candidates.data[0].id
|
||||
elif temp == 0.0:
|
||||
id = self._ctx.sample_token_greedy(candidates=self._candidates)
|
||||
elif mirostat_mode == 1:
|
||||
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
|
||||
id = self._ctx.sample_token_mirostat(
|
||||
candidates=self._candidates,
|
||||
tau=mirostat_tau,
|
||||
eta=mirostat_eta,
|
||||
mu=ctypes.pointer(self._mirostat_mu),
|
||||
m=100,
|
||||
)
|
||||
elif mirostat_mode == 2:
|
||||
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
|
||||
id = self._ctx.sample_token_mirostat_v2(
|
||||
candidates=self._candidates,
|
||||
tau=mirostat_tau,
|
||||
eta=mirostat_eta,
|
||||
mu=ctypes.pointer(self._mirostat_mu),
|
||||
)
|
||||
else:
|
||||
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
|
||||
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
|
||||
self._ctx.sample_typical(
|
||||
candidates=self._candidates, p=typical_p, min_keep=1
|
||||
)
|
||||
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
|
||||
self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
|
||||
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
|
||||
id = self._ctx.sample_token(candidates=self._candidates)
|
||||
if grammar is not None:
|
||||
self._ctx.grammar_accept_token(grammar=grammar, token=id)
|
||||
return id
|
||||
|
||||
def generate(
|
||||
|
@ -656,34 +636,56 @@ class Llama:
|
|||
if grammar is not None:
|
||||
grammar.reset()
|
||||
|
||||
sample_idx = self.n_tokens + len(tokens) - 1
|
||||
tokens = list(tokens)
|
||||
|
||||
# Eval and sample
|
||||
while True:
|
||||
self.eval(tokens)
|
||||
token = self.sample(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
temp=temp,
|
||||
repeat_penalty=repeat_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
penalize_nl=penalize_nl,
|
||||
)
|
||||
if stopping_criteria is not None and stopping_criteria(
|
||||
self._input_ids, self._scores[-1, :]
|
||||
):
|
||||
return
|
||||
tokens_or_none = yield token
|
||||
tokens = [token]
|
||||
if tokens_or_none is not None:
|
||||
tokens.extend(tokens_or_none)
|
||||
while sample_idx < self.n_tokens:
|
||||
token = self.sample(
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
min_p=min_p,
|
||||
typical_p=typical_p,
|
||||
temp=temp,
|
||||
repeat_penalty=repeat_penalty,
|
||||
frequency_penalty=frequency_penalty,
|
||||
presence_penalty=presence_penalty,
|
||||
tfs_z=tfs_z,
|
||||
mirostat_mode=mirostat_mode,
|
||||
mirostat_tau=mirostat_tau,
|
||||
mirostat_eta=mirostat_eta,
|
||||
logits_processor=logits_processor,
|
||||
grammar=grammar,
|
||||
penalize_nl=penalize_nl,
|
||||
idx=sample_idx,
|
||||
)
|
||||
|
||||
sample_idx += 1
|
||||
if stopping_criteria is not None and stopping_criteria(
|
||||
self._input_ids, self._scores[-1, :]
|
||||
):
|
||||
return
|
||||
tokens_or_none = yield token
|
||||
tokens.clear()
|
||||
tokens.append(token)
|
||||
if tokens_or_none is not None:
|
||||
tokens.extend(tokens_or_none)
|
||||
|
||||
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
|
||||
self.n_tokens = sample_idx
|
||||
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
|
||||
break
|
||||
|
||||
if self.draft_model is not None:
|
||||
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
|
||||
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
|
||||
tokens.extend(
|
||||
draft_tokens.astype(int)[
|
||||
: self._n_ctx - self.n_tokens - len(tokens)
|
||||
]
|
||||
)
|
||||
|
||||
def create_embedding(
|
||||
self, input: Union[str, List[str]], model: Optional[str] = None
|
||||
|
|
|
@ -445,7 +445,7 @@ class llama_model_params(Structure):
|
|||
# uint32_t n_batch; // prompt processing maximum batch size
|
||||
# uint32_t n_threads; // number of threads to use for generation
|
||||
# uint32_t n_threads_batch; // number of threads to use for batch processing
|
||||
# int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
# int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
|
||||
|
||||
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
|
||||
# float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
|
@ -502,7 +502,7 @@ class llama_context_params(Structure):
|
|||
("n_batch", c_uint32),
|
||||
("n_threads", c_uint32),
|
||||
("n_threads_batch", c_uint32),
|
||||
("rope_scaling_type", c_int8),
|
||||
("rope_scaling_type", c_int32),
|
||||
("rope_freq_base", c_float),
|
||||
("rope_freq_scale", c_float),
|
||||
("yarn_ext_factor", c_float),
|
||||
|
|
64
llama_cpp/llama_speculative.py
Normal file
64
llama_cpp/llama_speculative.py
Normal file
|
@ -0,0 +1,64 @@
|
|||
import abc
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
|
||||
class LlamaDraftModel(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def __call__(
|
||||
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
|
||||
) -> npt.NDArray[np.intc]:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LlamaPromptLookupDecoding(LlamaDraftModel):
|
||||
"""Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
|
||||
|
||||
def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
|
||||
self.max_ngram_size = max_ngram_size
|
||||
self.num_pred_tokens = num_pred_tokens
|
||||
|
||||
@staticmethod
|
||||
def find_candidate_pred_tokens(
|
||||
input_ids: npt.NDArray[np.intc],
|
||||
max_ngram_size: int,
|
||||
num_pred_tokens: int,
|
||||
):
|
||||
input_length = input_ids.shape[0]
|
||||
|
||||
for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
|
||||
# Create sliding windows of size ngram_size
|
||||
windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
|
||||
|
||||
# Convert ngram to an array for comparison
|
||||
ngram_array = input_ids[-ngram_size:]
|
||||
|
||||
# Find where the windows match the ngram
|
||||
matches = np.all(windows == ngram_array, axis=1)
|
||||
|
||||
# Get the indices of matches
|
||||
match_indices = np.nonzero(matches)[0]
|
||||
|
||||
# Iterate through match indices to find a valid continuation
|
||||
for idx in match_indices:
|
||||
start_idx = idx + ngram_size
|
||||
end_idx = start_idx + num_pred_tokens
|
||||
end_idx = min(end_idx, input_length)
|
||||
|
||||
if start_idx < end_idx:
|
||||
return input_ids[start_idx:end_idx]
|
||||
|
||||
# If no match is found, return an empty array
|
||||
return np.array([], dtype=np.intc)
|
||||
|
||||
def __call__(
|
||||
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
|
||||
) -> npt.NDArray[np.intc]:
|
||||
return self.find_candidate_pred_tokens(
|
||||
input_ids=input_ids,
|
||||
max_ngram_size=self.max_ngram_size,
|
||||
num_pred_tokens=self.num_pred_tokens,
|
||||
)
|
|
@ -5,6 +5,7 @@ import json
|
|||
from typing import Dict, Optional, Union, List
|
||||
|
||||
import llama_cpp
|
||||
import llama_cpp.llama_speculative as llama_speculative
|
||||
|
||||
from llama_cpp.server.settings import ModelSettings
|
||||
|
||||
|
@ -92,6 +93,12 @@ class LlamaProxy:
|
|||
)
|
||||
)
|
||||
|
||||
draft_model = None
|
||||
if settings.draft_model is not None:
|
||||
draft_model = llama_speculative.LlamaPromptLookupDecoding(
|
||||
num_pred_tokens=settings.draft_model_num_pred_tokens
|
||||
)
|
||||
|
||||
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
|
||||
if settings.kv_overrides is not None:
|
||||
assert isinstance(settings.kv_overrides, list)
|
||||
|
@ -147,6 +154,8 @@ class LlamaProxy:
|
|||
# Chat Format Params
|
||||
chat_format=settings.chat_format,
|
||||
chat_handler=chat_handler,
|
||||
# Speculative Decoding
|
||||
draft_model=draft_model,
|
||||
# Misc
|
||||
verbose=settings.verbose,
|
||||
)
|
||||
|
|
|
@ -143,6 +143,15 @@ class ModelSettings(BaseSettings):
|
|||
default=None,
|
||||
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
|
||||
)
|
||||
# Speculative Decoding
|
||||
draft_model: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
|
||||
)
|
||||
draft_model_num_pred_tokens: int = Field(
|
||||
default=10,
|
||||
description="Number of tokens to predict using the draft model.",
|
||||
)
|
||||
# Misc
|
||||
verbose: bool = Field(
|
||||
default=True, description="Whether to print debug information."
|
||||
|
|
16
tests/test_llama_speculative.py
Normal file
16
tests/test_llama_speculative.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import numpy as np
|
||||
|
||||
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
|
||||
|
||||
def test_find_candidate_pred_tokens():
|
||||
find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens
|
||||
|
||||
# Test Case 1: Matching ngram is found
|
||||
input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
|
||||
result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
|
||||
assert np.array_equal(result1, np.array([1, 2]))
|
||||
|
||||
# Test Case 2: Matching ngram is not found
|
||||
input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
|
||||
result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
|
||||
assert np.array_equal(result2, np.array([]))
|
2
vendor/llama.cpp
vendored
2
vendor/llama.cpp
vendored
|
@ -1 +1 @@
|
|||
Subproject commit 5cb04dbc16d1da38c8fdcc0111b40e67d00dd1c3
|
||||
Subproject commit 098f6d737b65134cf220d12b9b706e8cfc5e4610
|
Loading…
Reference in a new issue