Compare commits


11 commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| | b342398804 | Merge https://github.com/abetlen/llama-cpp-python | 2024-02-06 16:26:03 +05:30 |
| Andrei Betlen | 310fbf4e49 | Update llama.cpp | 2024-02-05 22:07:14 -05:00 |
| Andrei Betlen | 59760c85ed | fix: Use llama_log_callback to avoid suppress_stdout_stderr | 2024-02-05 21:52:12 -05:00 |
| Andrei Betlen | 3553b14670 | Update llama.cpp | 2024-02-05 13:26:50 -05:00 |
| Andrei | 7467f129e5 | Revert "Fix: fileno error google colab (#729) (#1156)" (#1157). This reverts commit bebfba0f08. | 2024-02-02 12:18:55 -05:00 |
| Dulsara | bebfba0f08 | Fix: fileno error google colab (#729) (#1156). Instead of using a devnull, just made a dummy class with a 'write()' method that does nothing. | 2024-02-02 12:05:46 -05:00 |
| Andrei Betlen | 8a5911bd5d | Update llama.cpp | 2024-02-02 09:41:27 -05:00 |
| Andrei Betlen | de526d0214 | Update llama.cpp | 2024-02-01 12:35:31 -05:00 |
| Andrei Betlen | 3322eadbf3 | Bump version | 2024-01-31 15:10:18 -05:00 |
| Andrei Betlen | a8cb34eacd | Update llama.cpp | 2024-01-31 15:05:51 -05:00 |
| Andrei | fb762a6041 | Add speculative decoding (#1120); full commit message below | 2024-01-31 14:08:14 -05:00 |

Commit message for fb762a6041 (Add speculative decoding, #1120):

* Add draft model param to llama class, implement basic prompt lookup decoding draft model

* Use samplingcontext for sampling

* Use 1d array

* Use draft model for sampling

* Fix dumb mistake

* Allow for later extensions to the LlamaDraftModel api

* Cleanup

* Adaptive candidate prediction

* Update implementation to match hf transformers

* Tuning

* Fix bug where last token was not used for ngram prediction

* Remove heuristic for num_pred_tokens (no benefit)

* fix: n_candidates bug.

* Add draft_model_num_pred_tokens server setting

* Cleanup

* Update README
12 changed files with 277 additions and 127 deletions

CHANGELOG.md

@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 ## [Unreleased]

+## [0.2.38]
+
+- feat: Update llama.cpp to ggerganov/llama.cpp@1cfb5372cf5707c8ec6dde7c874f4a44a6c4c915
+- feat: Add speculative decoding by @abetlen in #1120
+- fix: Pass raise_exception and add_generation_prompt to jinja2 chat template 078cca0361bf5a94d2cf52ed04980d20e32d6f95
+
 ## [0.2.37]

 - feat: Update llama.cpp to ggerganov/llama.cpp@fea4fd4ba7f6b754ac795387b275e1a014a77bde

README.md

@@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process
     )
 ```

+### Speculative Decoding
+
+`llama-cpp-python` supports speculative decoding, which allows the model to generate completions based on a draft model.
+
+The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class.
+
+Just pass this as a draft model to the `Llama` class during initialization.
+
+```python
+from llama_cpp import Llama
+from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
+
+llama = Llama(
+    model_path="path/to/model.gguf",
+    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10)  # num_pred_tokens is the number of tokens to predict; 10 is the default and generally good for GPUs, 2 performs better for CPU-only machines.
+)
+```
+
 ### Adjusting the Context Window

 The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.
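As a usage note for the README section added above: once the `Llama` instance is constructed with a draft model, completions are requested exactly as without one; speculative decoding only changes how tokens are proposed and verified internally. A minimal sketch (the model path, `n_ctx` value, prompt, and `max_tokens` are illustrative placeholders, not part of the diff):

```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding

# Illustrative values only: path, context size, and prompt are placeholders.
llama = Llama(
    model_path="path/to/model.gguf",
    n_ctx=2048,  # the README notes the default context window is 512 tokens
    draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10),
)

# The completion API is unchanged; drafted tokens are verified internally.
output = llama.create_completion(
    "Q: Name the planets in the solar system? A: ",
    max_tokens=64,
)
print(output["choices"][0]["text"])
```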

llama_cpp/__init__.py

@@ -1,4 +1,4 @@
 from .llama_cpp import *
 from .llama import *

-__version__ = "0.2.37"
+__version__ = "0.2.38"

llama_cpp/_internals.py

@@ -18,8 +18,6 @@ from .llama_grammar import LlamaGrammar

 import llama_cpp.llama_cpp as llama_cpp

-from ._utils import suppress_stdout_stderr
-
 # Python wrappers over llama.h structs

@@ -30,7 +28,6 @@ class _LlamaModel:
     _llama_free_model = None

     # NOTE: this must be "saved" here to avoid exceptions when calling __del__
-    _suppress_stdout_stderr = suppress_stdout_stderr

     def __init__(
         self,

@@ -48,16 +45,14 @@ class _LlamaModel:
         if not os.path.exists(path_model):
             raise ValueError(f"Model path does not exist: {path_model}")

-        with self._suppress_stdout_stderr(disable=self.verbose):
-            self.model = llama_cpp.llama_load_model_from_file(
-                self.path_model.encode("utf-8"), self.params
-            )
+        self.model = llama_cpp.llama_load_model_from_file(
+            self.path_model.encode("utf-8"), self.params
+        )

     def __del__(self):
-        with self._suppress_stdout_stderr(disable=self.verbose):
-            if self.model is not None and self._llama_free_model is not None:
-                self._llama_free_model(self.model)
-                self.model = None
+        if self.model is not None and self._llama_free_model is not None:
+            self._llama_free_model(self.model)
+            self.model = None

     def vocab_type(self) -> int:
         assert self.model is not None

@@ -240,8 +235,6 @@ class _LlamaContext:
     NOTE: For stability it's recommended you use the Llama class instead."""

     _llama_free = None

-    # NOTE: this must be "saved" here to avoid exceptions when calling __del__
-    _suppress_stdout_stderr = suppress_stdout_stderr

     def __init__(
         self,

@@ -256,16 +249,16 @@ class _LlamaContext:
         self._llama_free = llama_cpp._lib.llama_free  # type: ignore

-        with self._suppress_stdout_stderr(disable=self.verbose):
-            self.ctx = llama_cpp.llama_new_context_with_model(
-                self.model.model, self.params
-            )
+        assert self.model.model is not None
+
+        self.ctx = llama_cpp.llama_new_context_with_model(
+            self.model.model, self.params
+        )

     def __del__(self):
-        with self._suppress_stdout_stderr(disable=self.verbose):
-            if self.ctx is not None and self._llama_free is not None:
-                self._llama_free(self.ctx)
-                self.ctx = None
+        if self.ctx is not None and self._llama_free is not None:
+            self._llama_free(self.ctx)
+            self.ctx = None

     def n_ctx(self) -> int:
         assert self.ctx is not None

@@ -493,8 +486,6 @@ class _LlamaContext:
 class _LlamaBatch:
     _llama_batch_free = None

-    # NOTE: this must be "saved" here to avoid exceptions when calling __del__
-    _suppress_stdout_stderr = suppress_stdout_stderr

     def __init__(
         self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True

@@ -506,16 +497,14 @@ class _LlamaBatch:
         self._llama_batch_free = llama_cpp._lib.llama_batch_free  # type: ignore

-        with self._suppress_stdout_stderr(disable=self.verbose):
-            self.batch = llama_cpp.llama_batch_init(
-                self.n_tokens, self.embd, self.n_seq_max
-            )
+        self.batch = llama_cpp.llama_batch_init(
+            self.n_tokens, self.embd, self.n_seq_max
+        )

     def __del__(self):
-        with self._suppress_stdout_stderr(disable=self.verbose):
-            if self.batch is not None and self._llama_batch_free is not None:
-                self._llama_batch_free(self.batch)
-                self.batch = None
+        if self.batch is not None and self._llama_batch_free is not None:
+            self._llama_batch_free(self.batch)
+            self.batch = None

     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         assert self.batch is not None
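For context on what these hunks remove: `suppress_stdout_stderr` silenced llama.cpp's native output by redirecting the process-level file descriptors, an approach that depends on `sys.stderr.fileno()` and therefore breaks in notebook environments (see the fileno-on-Colab commits bebfba0f08 and 7467f129e5 above). Below is a minimal sketch of that general pattern, as an illustration of the approach being replaced, not the exact removed code:

```python
import os
import sys
from contextlib import contextmanager


@contextmanager
def suppress_native_stderr():
    """Redirect the C-level stderr (fd 2) to /dev/null for the duration of the block.

    Illustrative only: notebook environments often replace sys.stderr with an
    object that has no real file descriptor, so fileno() raises and this whole
    approach breaks, which is why the diff drops it in favor of a llama.cpp
    log callback (see llama_cpp/_logger.py below).
    """
    stderr_fd = sys.stderr.fileno()            # fails on Colab-style fake streams
    saved_fd = os.dup(stderr_fd)               # keep a copy so we can restore it
    devnull_fd = os.open(os.devnull, os.O_WRONLY)
    try:
        os.dup2(devnull_fd, stderr_fd)         # native writes now go to /dev/null
        yield
    finally:
        os.dup2(saved_fd, stderr_fd)           # restore the original stderr
        os.close(saved_fd)
        os.close(devnull_fd)
```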

llama_cpp/_logger.py (new file, 37 lines)

@@ -0,0 +1,37 @@
import sys
import ctypes
import logging
import llama_cpp
# enum ggml_log_level {
# GGML_LOG_LEVEL_ERROR = 2,
# GGML_LOG_LEVEL_WARN = 3,
# GGML_LOG_LEVEL_INFO = 4,
# GGML_LOG_LEVEL_DEBUG = 5
# };
GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
2: logging.ERROR,
3: logging.WARNING,
4: logging.INFO,
5: logging.DEBUG,
}
logger = logging.getLogger("llama-cpp-python")
@llama_cpp.llama_log_callback
def llama_log_callback(
level: int,
text: bytes,
user_data: ctypes.c_void_p,
):
if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]:
print(text.decode("utf-8"), end="", flush=True, file=sys.stderr)
llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))
def set_verbose(verbose: bool):
logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
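Because the callback above gates output on the level of the `llama-cpp-python` logger before writing to stderr, native llama.cpp log verbosity can now be controlled with the standard `logging` module. A small sketch, assuming the package has been imported so the callback is registered:

```python
import logging

import llama_cpp  # importing the package loads llama_cpp._logger and registers the callback

log = logging.getLogger("llama-cpp-python")

# Surface everything the native library reports, down to DEBUG ...
log.setLevel(logging.DEBUG)

# ... or keep only errors, which is also what set_verbose(False) does.
log.setLevel(logging.ERROR)
```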

llama_cpp/llama.py

@@ -30,16 +30,20 @@ from .llama_cache import (

 import llama_cpp.llama_cpp as llama_cpp
 import llama_cpp.llama_chat_format as llama_chat_format

+from llama_cpp.llama_speculative import LlamaDraftModel
+
 import numpy as np
 import numpy.typing as npt

-from ._utils import suppress_stdout_stderr
 from ._internals import (
     _LlamaModel,  # type: ignore
     _LlamaContext,  # type: ignore
     _LlamaBatch,  # type: ignore
     _LlamaTokenDataArray,  # type: ignore
+    _LlamaSamplingParams,  # type: ignore
+    _LlamaSamplingContext,  # type: ignore
 )
+from ._logger import set_verbose


 class Llama:

@@ -89,6 +93,8 @@ class Llama:
         # Chat Format Params
         chat_format: Optional[str] = None,
         chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
+        # Speculative Decoding
+        draft_model: Optional[LlamaDraftModel] = None,
         # Misc
         verbose: bool = True,
         # Extra Params

@@ -152,6 +158,7 @@ class Llama:
             numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
             chat_format: String specifying the chat format to use when calling create_chat_completion.
             chat_handler: Optional chat handler to use when calling create_chat_completion.
+            draft_model: Optional draft model to use for speculative decoding.
             verbose: Print verbose output to stderr.

         Raises:

@@ -162,10 +169,11 @@ class Llama:
         """
         self.verbose = verbose

+        set_verbose(verbose)
+
         self.numa = numa
         if not Llama.__backend_initialized:
-            with suppress_stdout_stderr(disable=self.verbose):
-                llama_cpp.llama_backend_init(self.numa)
+            llama_cpp.llama_backend_init(self.numa)
             Llama.__backend_initialized = True

         self.model_path = model_path

@@ -315,6 +323,8 @@ class Llama:
         self.chat_format = chat_format
         self.chat_handler = chat_handler

+        self.draft_model = draft_model
+
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()

@@ -503,6 +513,7 @@ class Llama:
         penalize_nl: bool = True,
         logits_processor: Optional[LogitsProcessorList] = None,
         grammar: Optional[LlamaGrammar] = None,
+        idx: Optional[int] = None,
     ):
         """Sample a token from the model.

@@ -517,77 +528,46 @@ class Llama:
         """
         assert self._ctx is not None
         assert self.n_tokens > 0
-        last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - self.n_tokens
-        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
-        last_n_tokens_size = len(last_n_tokens_data)
-        n_vocab = self._n_vocab
-        n_ctx = self._n_ctx
-        top_k = n_vocab if top_k <= 0 else top_k
-        last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
-        last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
-            *last_n_tokens_data
-        )
-        logits: npt.NDArray[np.single] = self._scores[-1, :]
+
+        if idx is None:
+            logits: npt.NDArray[np.single] = self._scores[-1, :]
+        else:
+            logits = self._scores[idx, :]

         if logits_processor is not None:
-            logits[:] = logits_processor(self._input_ids, logits)
+            logits[:] = (
+                logits_processor(self._input_ids, logits)
+                if idx is None
+                else logits_processor(self._input_ids[:idx], logits)
+            )

-        nl_logit = logits[self._token_nl]
-        self._candidates.copy_logits(logits)
-        self._ctx.sample_repetition_penalties(
-            candidates=self._candidates,
-            last_tokens_data=last_n_tokens_data_c,
-            penalty_last_n=last_n_tokens_size,
+        sampling_params = _LlamaSamplingParams(
+            top_k=top_k,
+            top_p=top_p,
+            min_p=min_p,
+            tfs_z=tfs_z,
+            typical_p=typical_p,
+            temp=temp,
+            penalty_last_n=self.last_n_tokens_size,
             penalty_repeat=repeat_penalty,
             penalty_freq=frequency_penalty,
             penalty_present=presence_penalty,
+            mirostat=mirostat_mode,
+            mirostat_tau=mirostat_tau,
+            mirostat_eta=mirostat_eta,
+            penalize_nl=penalize_nl,
+        )
+        sampling_context = _LlamaSamplingContext(
+            params=sampling_params,
+            grammar=grammar,
+        )
+        sampling_context.prev = list(self.eval_tokens)
+        id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits)
+        sampling_context.accept(
+            ctx_main=self._ctx,
+            id=id,
+            apply_grammar=grammar is not None,
         )
-        if not penalize_nl:
-            self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
-                nl_logit
-            )
-        if grammar is not None:
-            self._ctx.sample_grammar(
-                candidates=self._candidates,
-                grammar=grammar,
-            )
-        if temp < 0.0:
-            self._ctx.sample_softmax(candidates=self._candidates)
-            id = self._candidates.candidates.data[0].id
-        elif temp == 0.0:
-            id = self._ctx.sample_token_greedy(candidates=self._candidates)
-        elif mirostat_mode == 1:
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token_mirostat(
-                candidates=self._candidates,
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=ctypes.pointer(self._mirostat_mu),
-                m=100,
-            )
-        elif mirostat_mode == 2:
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token_mirostat_v2(
-                candidates=self._candidates,
-                tau=mirostat_tau,
-                eta=mirostat_eta,
-                mu=ctypes.pointer(self._mirostat_mu),
-            )
-        else:
-            self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
-            self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
-            self._ctx.sample_typical(
-                candidates=self._candidates, p=typical_p, min_keep=1
-            )
-            self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
-            self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
-            self._ctx.sample_temp(candidates=self._candidates, temp=temp)
-            id = self._ctx.sample_token(candidates=self._candidates)
-        if grammar is not None:
-            self._ctx.grammar_accept_token(grammar=grammar, token=id)
         return id

     def generate(

@@ -656,34 +636,56 @@ class Llama:
         if grammar is not None:
             grammar.reset()

+        sample_idx = self.n_tokens + len(tokens) - 1
+        tokens = list(tokens)
+
         # Eval and sample
         while True:
             self.eval(tokens)
-            token = self.sample(
-                top_k=top_k,
-                top_p=top_p,
-                min_p=min_p,
-                typical_p=typical_p,
-                temp=temp,
-                repeat_penalty=repeat_penalty,
-                frequency_penalty=frequency_penalty,
-                presence_penalty=presence_penalty,
-                tfs_z=tfs_z,
-                mirostat_mode=mirostat_mode,
-                mirostat_tau=mirostat_tau,
-                mirostat_eta=mirostat_eta,
-                logits_processor=logits_processor,
-                grammar=grammar,
-                penalize_nl=penalize_nl,
-            )
-            if stopping_criteria is not None and stopping_criteria(
-                self._input_ids, self._scores[-1, :]
-            ):
-                return
-            tokens_or_none = yield token
-            tokens = [token]
-            if tokens_or_none is not None:
-                tokens.extend(tokens_or_none)
+            while sample_idx < self.n_tokens:
+                token = self.sample(
+                    top_k=top_k,
+                    top_p=top_p,
+                    min_p=min_p,
+                    typical_p=typical_p,
+                    temp=temp,
+                    repeat_penalty=repeat_penalty,
+                    frequency_penalty=frequency_penalty,
+                    presence_penalty=presence_penalty,
+                    tfs_z=tfs_z,
+                    mirostat_mode=mirostat_mode,
+                    mirostat_tau=mirostat_tau,
+                    mirostat_eta=mirostat_eta,
+                    logits_processor=logits_processor,
+                    grammar=grammar,
+                    penalize_nl=penalize_nl,
+                    idx=sample_idx,
+                )
+
+                sample_idx += 1
+                if stopping_criteria is not None and stopping_criteria(
+                    self._input_ids, self._scores[-1, :]
+                ):
+                    return
+                tokens_or_none = yield token
+                tokens.clear()
+                tokens.append(token)
+                if tokens_or_none is not None:
+                    tokens.extend(tokens_or_none)
+
+                if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
+                    self.n_tokens = sample_idx
+                    self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+                    break
+
+            if self.draft_model is not None:
+                self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
+                draft_tokens = self.draft_model(self.input_ids[: self.n_tokens + len(tokens)])
+                tokens.extend(
+                    draft_tokens.astype(int)[
+                        : self._n_ctx - self.n_tokens - len(tokens)
+                    ]
+                )

     def create_embedding(
         self, input: Union[str, List[str]], model: Optional[str] = None
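To make the control flow of the rewritten `generate()` easier to follow, here is a simplified, self-contained sketch of the same draft-and-verify idea. The helpers `eval_batch`, `sample_at`, and `propose_draft` are hypothetical stand-ins for the real `Llama` internals; this illustrates the loop structure, it is not the library's implementation. Drafted tokens are appended to the next batch, scored in a single forward pass, and then sampling walks forward until a sampled token disagrees with a drafted one, at which point the speculative tail is discarded (the diff does this by truncating `n_tokens` and removing the corresponding KV cache entries).

```python
from typing import Callable, List, Sequence


def speculative_generate(
    prompt_tokens: Sequence[int],
    eval_batch: Callable[[List[int]], None],          # hypothetical: scores a batch of new tokens
    sample_at: Callable[[int], int],                  # hypothetical: samples from the logits at position idx
    propose_draft: Callable[[List[int]], List[int]],  # hypothetical draft model, e.g. prompt lookup
    max_tokens: int,
) -> List[int]:
    ids: List[int] = list(prompt_tokens)    # every token position fed to the model so far
    batch: List[int] = list(prompt_tokens)  # tokens for the next eval call
    n_scored = 0                            # counterpart of Llama.n_tokens: positions already evaluated
    sample_idx = len(ids) - 1               # first position whose logits we will sample from
    produced: List[int] = []

    while len(produced) < max_tokens:
        eval_batch(batch)                   # one forward pass scores len(batch) new positions
        n_scored += len(batch)

        # Verify: walk over the freshly scored positions, sampling one token each.
        while sample_idx < n_scored and len(produced) < max_tokens:
            token = sample_at(sample_idx)
            produced.append(token)
            sample_idx += 1
            batch = [token]                 # by default only the accepted token is fed next round
            # A drafted token that disagrees with the token actually sampled
            # invalidates the rest of the speculative tail: roll back and re-draft.
            if sample_idx < n_scored and token != ids[sample_idx]:
                n_scored = sample_idx       # mirrors `self.n_tokens = sample_idx`
                ids = ids[:sample_idx]      # mirrors `kv_cache_seq_rm(-1, self.n_tokens, -1)`
                break

        # Draft: ask the cheap draft model for candidates and append them to the
        # next batch so the main model scores them all in one eval.
        draft = list(propose_draft(ids + batch))
        ids.extend(batch)
        ids.extend(draft)
        batch = batch + draft

    return produced
```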

llama_cpp/llama_cpp.py

@@ -445,7 +445,7 @@ class llama_model_params(Structure):
 # uint32_t n_batch;         // prompt processing maximum batch size
 # uint32_t n_threads;       // number of threads to use for generation
 # uint32_t n_threads_batch; // number of threads to use for batch processing
-# int8_t  rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
+# int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 # // ref: https://github.com/ggerganov/llama.cpp/pull/2054
 # float rope_freq_base;   // RoPE base frequency, 0 = from model

@@ -502,7 +502,7 @@ class llama_context_params(Structure):
         ("n_batch", c_uint32),
         ("n_threads", c_uint32),
         ("n_threads_batch", c_uint32),
-        ("rope_scaling_type", c_int8),
+        ("rope_scaling_type", c_int32),
         ("rope_freq_base", c_float),
         ("rope_freq_scale", c_float),
         ("yarn_ext_factor", c_float),

llama_cpp/llama_speculative.py (new file, 64 lines)

@@ -0,0 +1,64 @@
import abc
from typing import Any
import numpy as np
import numpy.typing as npt
class LlamaDraftModel(abc.ABC):
@abc.abstractmethod
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
raise NotImplementedError()
class LlamaPromptLookupDecoding(LlamaDraftModel):
"""Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
self.max_ngram_size = max_ngram_size
self.num_pred_tokens = num_pred_tokens
@staticmethod
def find_candidate_pred_tokens(
input_ids: npt.NDArray[np.intc],
max_ngram_size: int,
num_pred_tokens: int,
):
input_length = input_ids.shape[0]
for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
# Create sliding windows of size ngram_size
windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
# Convert ngram to an array for comparison
ngram_array = input_ids[-ngram_size:]
# Find where the windows match the ngram
matches = np.all(windows == ngram_array, axis=1)
# Get the indices of matches
match_indices = np.nonzero(matches)[0]
# Iterate through match indices to find a valid continuation
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + num_pred_tokens
end_idx = min(end_idx, input_length)
if start_idx < end_idx:
return input_ids[start_idx:end_idx]
# If no match is found, return an empty array
return np.array([], dtype=np.intc)
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
return self.find_candidate_pred_tokens(
input_ids=input_ids,
max_ngram_size=self.max_ngram_size,
num_pred_tokens=self.num_pred_tokens,
)
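The abstract `LlamaDraftModel` interface above is what `Llama.generate()` calls to obtain draft tokens, so alternative drafting strategies can be plugged in by subclassing it. A hypothetical sketch (the constant-sequence strategy here is purely illustrative and not provided by the diff):

```python
from typing import Any, List

import numpy as np
import numpy.typing as npt

from llama_cpp.llama_speculative import LlamaDraftModel


class ConstantDraftModel(LlamaDraftModel):
    """Toy draft model that always proposes the same fixed token sequence."""

    def __init__(self, draft: List[int]):
        self.draft = np.array(draft, dtype=np.intc)

    def __call__(
        self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
    ) -> npt.NDArray[np.intc]:
        # Llama.generate() treats the return value purely as candidate tokens:
        # it evaluates them and keeps only the prefix that matches what it
        # would have sampled anyway, so a poor draft costs compute, not quality.
        return self.draft
```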

llama_cpp/server/model.py

@@ -5,6 +5,7 @@ import json
 from typing import Dict, Optional, Union, List

 import llama_cpp
+import llama_cpp.llama_speculative as llama_speculative

 from llama_cpp.server.settings import ModelSettings

@@ -92,6 +93,12 @@ class LlamaProxy:
                 )
             )

+        draft_model = None
+        if settings.draft_model is not None:
+            draft_model = llama_speculative.LlamaPromptLookupDecoding(
+                num_pred_tokens=settings.draft_model_num_pred_tokens
+            )
+
         kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
         if settings.kv_overrides is not None:
             assert isinstance(settings.kv_overrides, list)

@@ -147,6 +154,8 @@ class LlamaProxy:
             # Chat Format Params
             chat_format=settings.chat_format,
             chat_handler=chat_handler,
+            # Speculative Decoding
+            draft_model=draft_model,
             # Misc
             verbose=settings.verbose,
         )

llama_cpp/server/settings.py

@@ -143,6 +143,15 @@ class ModelSettings(BaseSettings):
         default=None,
         description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
     )
+    # Speculative Decoding
+    draft_model: Optional[str] = Field(
+        default=None,
+        description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
+    )
+    draft_model_num_pred_tokens: int = Field(
+        default=10,
+        description="Number of tokens to predict using the draft model.",
+    )
     # Misc
     verbose: bool = Field(
         default=True, description="Whether to print debug information."
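With these settings, the server can enable prompt-lookup speculative decoding purely through configuration. A small sketch constructing the settings object directly; the model path is a placeholder and the pre-existing required `model` field on `ModelSettings` is assumed, while the two draft fields are taken from the diff above:

```python
from llama_cpp.server.settings import ModelSettings

# The field description above lists prompt-lookup-decoding as the only method,
# and LlamaProxy enables it whenever draft_model is set.
settings = ModelSettings(
    model="path/to/model.gguf",          # placeholder path
    draft_model="prompt-lookup-decoding",
    draft_model_num_pred_tokens=10,      # default shown in the diff
)

print(settings.draft_model, settings.draft_model_num_pred_tokens)
```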

tests/test_llama_speculative.py (new file)

@@ -0,0 +1,16 @@
import numpy as np
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
def test_find_candidate_pred_tokens():
find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens
# Test Case 1: Matching ngram is found
input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
assert np.array_equal(result1, np.array([1, 2]))
# Test Case 2: Matching ngram is not found
input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
assert np.array_equal(result2, np.array([]))

vendor/llama.cpp (vendored submodule, 2 changed lines)

@@ -1 +1 @@
-Subproject commit 5cb04dbc16d1da38c8fdcc0111b40e67d00dd1c3
+Subproject commit 098f6d737b65134cf220d12b9b706e8cfc5e4610