This commit is contained in:
baalajimaestro 2024-02-06 16:26:03 +05:30
commit b342398804
Signed by: baalajimaestro
GPG key ID: F93C394FE9BBAFD5
12 changed files with 277 additions and 127 deletions

View file

@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
## [0.2.38]
- feat: Update llama.cpp to ggerganov/llama.cpp@1cfb5372cf5707c8ec6dde7c874f4a44a6c4c915
- feat: Add speculative decoding by @abetlen in #1120
- fix: Pass raise_exception and add_generation_prompt to jinja2 chat template 078cca0361bf5a94d2cf52ed04980d20e32d6f95
## [0.2.37]
- feat: Update llama.cpp to ggerganov/llama.cpp@fea4fd4ba7f6b754ac795387b275e1a014a77bde

View file

@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process
)
```
### Speculative Decoding
`llama-cpp-python` supports speculative decoding which allows the model to generate completions based on a draft model.
The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class.
Just pass this as a draft model to the `Llama` class during initialization.
```python
from llama_cpp import Llama
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
llama = Llama(
model_path="path/to/model.gguf",
draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10) # num_pred_tokens is the number of tokens to predict 10 is the default and generally good for gpu, 2 performs better for cpu-only machines.
)
```
### Adjusting the Context Window
The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements.

View file

@ -1,4 +1,4 @@
from .llama_cpp import *
from .llama import *
__version__ = "0.2.37"
__version__ = "0.2.38"

View file

@ -18,8 +18,6 @@ from .llama_grammar import LlamaGrammar
import llama_cpp.llama_cpp as llama_cpp
from ._utils import suppress_stdout_stderr
# Python wrappers over llama.h structs
@ -30,7 +28,6 @@ class _LlamaModel:
_llama_free_model = None
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
_suppress_stdout_stderr = suppress_stdout_stderr
def __init__(
self,
@ -48,13 +45,11 @@ class _LlamaModel:
if not os.path.exists(path_model):
raise ValueError(f"Model path does not exist: {path_model}")
with self._suppress_stdout_stderr(disable=self.verbose):
self.model = llama_cpp.llama_load_model_from_file(
self.path_model.encode("utf-8"), self.params
)
def __del__(self):
with self._suppress_stdout_stderr(disable=self.verbose):
if self.model is not None and self._llama_free_model is not None:
self._llama_free_model(self.model)
self.model = None
@ -240,8 +235,6 @@ class _LlamaContext:
NOTE: For stability it's recommended you use the Llama class instead."""
_llama_free = None
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
_suppress_stdout_stderr = suppress_stdout_stderr
def __init__(
self,
@ -256,13 +249,13 @@ class _LlamaContext:
self._llama_free = llama_cpp._lib.llama_free # type: ignore
with self._suppress_stdout_stderr(disable=self.verbose):
assert self.model.model is not None
self.ctx = llama_cpp.llama_new_context_with_model(
self.model.model, self.params
)
def __del__(self):
with self._suppress_stdout_stderr(disable=self.verbose):
if self.ctx is not None and self._llama_free is not None:
self._llama_free(self.ctx)
self.ctx = None
@ -493,8 +486,6 @@ class _LlamaContext:
class _LlamaBatch:
_llama_batch_free = None
# NOTE: this must be "saved" here to avoid exceptions when calling __del__
_suppress_stdout_stderr = suppress_stdout_stderr
def __init__(
self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True
@ -506,13 +497,11 @@ class _LlamaBatch:
self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore
with self._suppress_stdout_stderr(disable=self.verbose):
self.batch = llama_cpp.llama_batch_init(
self.n_tokens, self.embd, self.n_seq_max
)
def __del__(self):
with self._suppress_stdout_stderr(disable=self.verbose):
if self.batch is not None and self._llama_batch_free is not None:
self._llama_batch_free(self.batch)
self.batch = None

37
llama_cpp/_logger.py Normal file
View file

@ -0,0 +1,37 @@
import sys
import ctypes
import logging
import llama_cpp
# enum ggml_log_level {
# GGML_LOG_LEVEL_ERROR = 2,
# GGML_LOG_LEVEL_WARN = 3,
# GGML_LOG_LEVEL_INFO = 4,
# GGML_LOG_LEVEL_DEBUG = 5
# };
GGML_LOG_LEVEL_TO_LOGGING_LEVEL = {
2: logging.ERROR,
3: logging.WARNING,
4: logging.INFO,
5: logging.DEBUG,
}
logger = logging.getLogger("llama-cpp-python")
@llama_cpp.llama_log_callback
def llama_log_callback(
level: int,
text: bytes,
user_data: ctypes.c_void_p,
):
if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]:
print(text.decode("utf-8"), end="", flush=True, file=sys.stderr)
llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0))
def set_verbose(verbose: bool):
logger.setLevel(logging.DEBUG if verbose else logging.ERROR)

View file

@ -30,16 +30,20 @@ from .llama_cache import (
import llama_cpp.llama_cpp as llama_cpp
import llama_cpp.llama_chat_format as llama_chat_format
from llama_cpp.llama_speculative import LlamaDraftModel
import numpy as np
import numpy.typing as npt
from ._utils import suppress_stdout_stderr
from ._internals import (
_LlamaModel, # type: ignore
_LlamaContext, # type: ignore
_LlamaBatch, # type: ignore
_LlamaTokenDataArray, # type: ignore
_LlamaSamplingParams, # type: ignore
_LlamaSamplingContext, # type: ignore
)
from ._logger import set_verbose
class Llama:
@ -89,6 +93,8 @@ class Llama:
# Chat Format Params
chat_format: Optional[str] = None,
chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None,
# Speculative Decoding
draft_model: Optional[LlamaDraftModel] = None,
# Misc
verbose: bool = True,
# Extra Params
@ -152,6 +158,7 @@ class Llama:
numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init)
chat_format: String specifying the chat format to use when calling create_chat_completion.
chat_handler: Optional chat handler to use when calling create_chat_completion.
draft_model: Optional draft model to use for speculative decoding.
verbose: Print verbose output to stderr.
Raises:
@ -162,9 +169,10 @@ class Llama:
"""
self.verbose = verbose
set_verbose(verbose)
self.numa = numa
if not Llama.__backend_initialized:
with suppress_stdout_stderr(disable=self.verbose):
llama_cpp.llama_backend_init(self.numa)
Llama.__backend_initialized = True
@ -315,6 +323,8 @@ class Llama:
self.chat_format = chat_format
self.chat_handler = chat_handler
self.draft_model = draft_model
self._n_vocab = self.n_vocab()
self._n_ctx = self.n_ctx()
@ -503,6 +513,7 @@ class Llama:
penalize_nl: bool = True,
logits_processor: Optional[LogitsProcessorList] = None,
grammar: Optional[LlamaGrammar] = None,
idx: Optional[int] = None,
):
"""Sample a token from the model.
@ -517,77 +528,46 @@ class Llama:
"""
assert self._ctx is not None
assert self.n_tokens > 0
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - self.n_tokens
) + self._input_ids[-self.last_n_tokens_size :].tolist()
last_n_tokens_size = len(last_n_tokens_data)
n_vocab = self._n_vocab
n_ctx = self._n_ctx
top_k = n_vocab if top_k <= 0 else top_k
last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size
last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)(
*last_n_tokens_data
)
if idx is None:
logits: npt.NDArray[np.single] = self._scores[-1, :]
else:
logits = self._scores[idx, :]
if logits_processor is not None:
logits[:] = logits_processor(self._input_ids, logits)
logits[:] = (
logits_processor(self._input_ids, logits)
if idx is None
else logits_processor(self._input_ids[:idx], logits)
)
nl_logit = logits[self._token_nl]
self._candidates.copy_logits(logits)
self._ctx.sample_repetition_penalties(
candidates=self._candidates,
last_tokens_data=last_n_tokens_data_c,
penalty_last_n=last_n_tokens_size,
sampling_params = _LlamaSamplingParams(
top_k=top_k,
top_p=top_p,
min_p=min_p,
tfs_z=tfs_z,
typical_p=typical_p,
temp=temp,
penalty_last_n=self.last_n_tokens_size,
penalty_repeat=repeat_penalty,
penalty_freq=frequency_penalty,
penalty_present=presence_penalty,
mirostat=mirostat_mode,
mirostat_tau=mirostat_tau,
mirostat_eta=mirostat_eta,
penalize_nl=penalize_nl,
)
if not penalize_nl:
self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float(
nl_logit
)
if grammar is not None:
self._ctx.sample_grammar(
candidates=self._candidates,
sampling_context = _LlamaSamplingContext(
params=sampling_params,
grammar=grammar,
)
if temp < 0.0:
self._ctx.sample_softmax(candidates=self._candidates)
id = self._candidates.candidates.data[0].id
elif temp == 0.0:
id = self._ctx.sample_token_greedy(candidates=self._candidates)
elif mirostat_mode == 1:
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token_mirostat(
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
mu=ctypes.pointer(self._mirostat_mu),
m=100,
sampling_context.prev = list(self.eval_tokens)
id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits)
sampling_context.accept(
ctx_main=self._ctx,
id=id,
apply_grammar=grammar is not None,
)
elif mirostat_mode == 2:
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token_mirostat_v2(
candidates=self._candidates,
tau=mirostat_tau,
eta=mirostat_eta,
mu=ctypes.pointer(self._mirostat_mu),
)
else:
self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1)
self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1)
self._ctx.sample_typical(
candidates=self._candidates, p=typical_p, min_keep=1
)
self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1)
self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1)
self._ctx.sample_temp(candidates=self._candidates, temp=temp)
id = self._ctx.sample_token(candidates=self._candidates)
if grammar is not None:
self._ctx.grammar_accept_token(grammar=grammar, token=id)
return id
def generate(
@ -656,9 +636,13 @@ class Llama:
if grammar is not None:
grammar.reset()
sample_idx = self.n_tokens + len(tokens) - 1
tokens = list(tokens)
# Eval and sample
while True:
self.eval(tokens)
while sample_idx < self.n_tokens:
token = self.sample(
top_k=top_k,
top_p=top_p,
@ -675,16 +659,34 @@ class Llama:
logits_processor=logits_processor,
grammar=grammar,
penalize_nl=penalize_nl,
idx=sample_idx,
)
sample_idx += 1
if stopping_criteria is not None and stopping_criteria(
self._input_ids, self._scores[-1, :]
):
return
tokens_or_none = yield token
tokens = [token]
tokens.clear()
tokens.append(token)
if tokens_or_none is not None:
tokens.extend(tokens_or_none)
if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]:
self.n_tokens = sample_idx
self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
break
if self.draft_model is not None:
self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens
draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)])
tokens.extend(
draft_tokens.astype(int)[
: self._n_ctx - self.n_tokens - len(tokens)
]
)
def create_embedding(
self, input: Union[str, List[str]], model: Optional[str] = None
) -> CreateEmbeddingResponse:

View file

@ -445,7 +445,7 @@ class llama_model_params(Structure):
# uint32_t n_batch; // prompt processing maximum batch size
# uint32_t n_threads; // number of threads to use for generation
# uint32_t n_threads_batch; // number of threads to use for batch processing
# int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
# int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
# // ref: https://github.com/ggerganov/llama.cpp/pull/2054
# float rope_freq_base; // RoPE base frequency, 0 = from model
@ -502,7 +502,7 @@ class llama_context_params(Structure):
("n_batch", c_uint32),
("n_threads", c_uint32),
("n_threads_batch", c_uint32),
("rope_scaling_type", c_int8),
("rope_scaling_type", c_int32),
("rope_freq_base", c_float),
("rope_freq_scale", c_float),
("yarn_ext_factor", c_float),

View file

@ -0,0 +1,64 @@
import abc
from typing import Any
import numpy as np
import numpy.typing as npt
class LlamaDraftModel(abc.ABC):
@abc.abstractmethod
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
raise NotImplementedError()
class LlamaPromptLookupDecoding(LlamaDraftModel):
"""Based on https://github.com/apoorvumang/prompt-lookup-decoding"""
def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10):
self.max_ngram_size = max_ngram_size
self.num_pred_tokens = num_pred_tokens
@staticmethod
def find_candidate_pred_tokens(
input_ids: npt.NDArray[np.intc],
max_ngram_size: int,
num_pred_tokens: int,
):
input_length = input_ids.shape[0]
for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1):
# Create sliding windows of size ngram_size
windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,))
# Convert ngram to an array for comparison
ngram_array = input_ids[-ngram_size:]
# Find where the windows match the ngram
matches = np.all(windows == ngram_array, axis=1)
# Get the indices of matches
match_indices = np.nonzero(matches)[0]
# Iterate through match indices to find a valid continuation
for idx in match_indices:
start_idx = idx + ngram_size
end_idx = start_idx + num_pred_tokens
end_idx = min(end_idx, input_length)
if start_idx < end_idx:
return input_ids[start_idx:end_idx]
# If no match is found, return an empty array
return np.array([], dtype=np.intc)
def __call__(
self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any
) -> npt.NDArray[np.intc]:
return self.find_candidate_pred_tokens(
input_ids=input_ids,
max_ngram_size=self.max_ngram_size,
num_pred_tokens=self.num_pred_tokens,
)

View file

@ -5,6 +5,7 @@ import json
from typing import Dict, Optional, Union, List
import llama_cpp
import llama_cpp.llama_speculative as llama_speculative
from llama_cpp.server.settings import ModelSettings
@ -92,6 +93,12 @@ class LlamaProxy:
)
)
draft_model = None
if settings.draft_model is not None:
draft_model = llama_speculative.LlamaPromptLookupDecoding(
num_pred_tokens=settings.draft_model_num_pred_tokens
)
kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None
if settings.kv_overrides is not None:
assert isinstance(settings.kv_overrides, list)
@ -147,6 +154,8 @@ class LlamaProxy:
# Chat Format Params
chat_format=settings.chat_format,
chat_handler=chat_handler,
# Speculative Decoding
draft_model=draft_model,
# Misc
verbose=settings.verbose,
)

View file

@ -143,6 +143,15 @@ class ModelSettings(BaseSettings):
default=None,
description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().",
)
# Speculative Decoding
draft_model: Optional[str] = Field(
default=None,
description="Method to use for speculative decoding. One of (prompt-lookup-decoding).",
)
draft_model_num_pred_tokens: int = Field(
default=10,
description="Number of tokens to predict using the draft model.",
)
# Misc
verbose: bool = Field(
default=True, description="Whether to print debug information."

View file

@ -0,0 +1,16 @@
import numpy as np
from llama_cpp.llama_speculative import LlamaPromptLookupDecoding
def test_find_candidate_pred_tokens():
find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens
# Test Case 1: Matching ngram is found
input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3])
result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2)
assert np.array_equal(result1, np.array([1, 2]))
# Test Case 2: Matching ngram is not found
input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2)
assert np.array_equal(result2, np.array([]))

2
vendor/llama.cpp vendored

@ -1 +1 @@
Subproject commit 5cb04dbc16d1da38c8fdcc0111b40e67d00dd1c3
Subproject commit 098f6d737b65134cf220d12b9b706e8cfc5e4610