From fb762a60411f53278454b4e9888c5bd9712d3779 Mon Sep 17 00:00:00 2001 From: Andrei Date: Wed, 31 Jan 2024 14:08:14 -0500 Subject: [PATCH 01/10] Add speculative decoding (#1120) * Add draft model param to llama class, implement basic prompt lookup decoding draft model * Use samplingcontext for sampling * Use 1d array * Use draft model for sampling * Fix dumb mistake * Allow for later extensions to the LlamaDraftModel api * Cleanup * Adaptive candidate prediction * Update implementation to match hf transformers * Tuning * Fix bug where last token was not used for ngram prediction * Remove heuristic for num_pred_tokens (no benefit) * fix: n_candidates bug. * Add draft_model_num_pred_tokens server setting * Cleanup * Update README --- README.md | 18 ++++ llama_cpp/llama.py | 179 ++++++++++++++++---------------- llama_cpp/llama_speculative.py | 64 ++++++++++++ llama_cpp/server/model.py | 9 ++ llama_cpp/server/settings.py | 9 ++ tests/test_llama_speculative.py | 16 +++ 6 files changed, 206 insertions(+), 89 deletions(-) create mode 100644 llama_cpp/llama_speculative.py create mode 100644 tests/test_llama_speculative.py diff --git a/README.md b/README.md index 0a77bbd..4131bb3 100644 --- a/README.md +++ b/README.md @@ -378,6 +378,24 @@ Then you'll need to use a custom chat handler to load the clip model and process ) ``` +### Speculative Decoding + +`llama-cpp-python` supports speculative decoding which allows the model to generate completions based on a draft model. + +The fastest way to use speculative decoding is through the `LlamaPromptLookupDecoding` class. + +Just pass this as a draft model to the `Llama` class during initialization. + +```python +from llama_cpp import Llama +from llama_cpp.llama_speculative import LlamaPromptLookupDecoding + +llama = Llama( + model_path="path/to/model.gguf", + draft_model=LlamaPromptLookupDecoding(num_pred_tokens=10) # num_pred_tokens is the number of tokens to predict 10 is the default and generally good for gpu, 2 performs better for cpu-only machines. +) +``` + ### Adjusting the Context Window The context window of the Llama models determines the maximum number of tokens that can be processed at once. By default, this is set to 512 tokens, but can be adjusted based on your requirements. diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index b5618c1..f00fd4f 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -30,6 +30,8 @@ from .llama_cache import ( import llama_cpp.llama_cpp as llama_cpp import llama_cpp.llama_chat_format as llama_chat_format +from llama_cpp.llama_speculative import LlamaDraftModel + import numpy as np import numpy.typing as npt @@ -39,6 +41,8 @@ from ._internals import ( _LlamaContext, # type: ignore _LlamaBatch, # type: ignore _LlamaTokenDataArray, # type: ignore + _LlamaSamplingParams, # type: ignore + _LlamaSamplingContext, # type: ignore ) @@ -89,6 +93,8 @@ class Llama: # Chat Format Params chat_format: Optional[str] = None, chat_handler: Optional[llama_chat_format.LlamaChatCompletionHandler] = None, + # Speculative Decoding + draft_model: Optional[LlamaDraftModel] = None, # Misc verbose: bool = True, # Extra Params @@ -152,6 +158,7 @@ class Llama: numa: Enable NUMA support. (NOTE: The initial value of this parameter is used for the remainder of the program as this value is set in llama_backend_init) chat_format: String specifying the chat format to use when calling create_chat_completion. chat_handler: Optional chat handler to use when calling create_chat_completion. + draft_model: Optional draft model to use for speculative decoding. verbose: Print verbose output to stderr. Raises: @@ -315,6 +322,8 @@ class Llama: self.chat_format = chat_format self.chat_handler = chat_handler + self.draft_model = draft_model + self._n_vocab = self.n_vocab() self._n_ctx = self.n_ctx() @@ -503,6 +512,7 @@ class Llama: penalize_nl: bool = True, logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, + idx: Optional[int] = None, ): """Sample a token from the model. @@ -517,77 +527,46 @@ class Llama: """ assert self._ctx is not None assert self.n_tokens > 0 - last_n_tokens_data = [llama_cpp.llama_token(0)] * max( - 0, self.last_n_tokens_size - self.n_tokens - ) + self._input_ids[-self.last_n_tokens_size :].tolist() - last_n_tokens_size = len(last_n_tokens_data) - n_vocab = self._n_vocab - n_ctx = self._n_ctx - top_k = n_vocab if top_k <= 0 else top_k - last_n_tokens_size = n_ctx if last_n_tokens_size < 0 else last_n_tokens_size - last_n_tokens_data_c = (llama_cpp.llama_token * last_n_tokens_size)( - *last_n_tokens_data - ) - logits: npt.NDArray[np.single] = self._scores[-1, :] + + if idx is None: + logits: npt.NDArray[np.single] = self._scores[-1, :] + else: + logits = self._scores[idx, :] if logits_processor is not None: - logits[:] = logits_processor(self._input_ids, logits) + logits[:] = ( + logits_processor(self._input_ids, logits) + if idx is None + else logits_processor(self._input_ids[:idx], logits) + ) - nl_logit = logits[self._token_nl] - self._candidates.copy_logits(logits) - self._ctx.sample_repetition_penalties( - candidates=self._candidates, - last_tokens_data=last_n_tokens_data_c, - penalty_last_n=last_n_tokens_size, + sampling_params = _LlamaSamplingParams( + top_k=top_k, + top_p=top_p, + min_p=min_p, + tfs_z=tfs_z, + typical_p=typical_p, + temp=temp, + penalty_last_n=self.last_n_tokens_size, penalty_repeat=repeat_penalty, penalty_freq=frequency_penalty, penalty_present=presence_penalty, + mirostat=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + penalize_nl=penalize_nl, + ) + sampling_context = _LlamaSamplingContext( + params=sampling_params, + grammar=grammar, + ) + sampling_context.prev = list(self.eval_tokens) + id = sampling_context.sample(ctx_main=self._ctx, logits_array=logits) + sampling_context.accept( + ctx_main=self._ctx, + id=id, + apply_grammar=grammar is not None, ) - if not penalize_nl: - self._candidates.candidates.data[self._token_nl].logit = llama_cpp.c_float( - nl_logit - ) - - if grammar is not None: - self._ctx.sample_grammar( - candidates=self._candidates, - grammar=grammar, - ) - - if temp < 0.0: - self._ctx.sample_softmax(candidates=self._candidates) - id = self._candidates.candidates.data[0].id - elif temp == 0.0: - id = self._ctx.sample_token_greedy(candidates=self._candidates) - elif mirostat_mode == 1: - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token_mirostat( - candidates=self._candidates, - tau=mirostat_tau, - eta=mirostat_eta, - mu=ctypes.pointer(self._mirostat_mu), - m=100, - ) - elif mirostat_mode == 2: - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token_mirostat_v2( - candidates=self._candidates, - tau=mirostat_tau, - eta=mirostat_eta, - mu=ctypes.pointer(self._mirostat_mu), - ) - else: - self._ctx.sample_top_k(candidates=self._candidates, k=top_k, min_keep=1) - self._ctx.sample_tail_free(candidates=self._candidates, z=tfs_z, min_keep=1) - self._ctx.sample_typical( - candidates=self._candidates, p=typical_p, min_keep=1 - ) - self._ctx.sample_top_p(candidates=self._candidates, p=top_p, min_keep=1) - self._ctx.sample_min_p(candidates=self._candidates, p=min_p, min_keep=1) - self._ctx.sample_temp(candidates=self._candidates, temp=temp) - id = self._ctx.sample_token(candidates=self._candidates) - if grammar is not None: - self._ctx.grammar_accept_token(grammar=grammar, token=id) return id def generate( @@ -656,34 +635,56 @@ class Llama: if grammar is not None: grammar.reset() + sample_idx = self.n_tokens + len(tokens) - 1 + tokens = list(tokens) + # Eval and sample while True: self.eval(tokens) - token = self.sample( - top_k=top_k, - top_p=top_p, - min_p=min_p, - typical_p=typical_p, - temp=temp, - repeat_penalty=repeat_penalty, - frequency_penalty=frequency_penalty, - presence_penalty=presence_penalty, - tfs_z=tfs_z, - mirostat_mode=mirostat_mode, - mirostat_tau=mirostat_tau, - mirostat_eta=mirostat_eta, - logits_processor=logits_processor, - grammar=grammar, - penalize_nl=penalize_nl, - ) - if stopping_criteria is not None and stopping_criteria( - self._input_ids, self._scores[-1, :] - ): - return - tokens_or_none = yield token - tokens = [token] - if tokens_or_none is not None: - tokens.extend(tokens_or_none) + while sample_idx < self.n_tokens: + token = self.sample( + top_k=top_k, + top_p=top_p, + min_p=min_p, + typical_p=typical_p, + temp=temp, + repeat_penalty=repeat_penalty, + frequency_penalty=frequency_penalty, + presence_penalty=presence_penalty, + tfs_z=tfs_z, + mirostat_mode=mirostat_mode, + mirostat_tau=mirostat_tau, + mirostat_eta=mirostat_eta, + logits_processor=logits_processor, + grammar=grammar, + penalize_nl=penalize_nl, + idx=sample_idx, + ) + + sample_idx += 1 + if stopping_criteria is not None and stopping_criteria( + self._input_ids, self._scores[-1, :] + ): + return + tokens_or_none = yield token + tokens.clear() + tokens.append(token) + if tokens_or_none is not None: + tokens.extend(tokens_or_none) + + if sample_idx < self.n_tokens and token != self._input_ids[sample_idx]: + self.n_tokens = sample_idx + self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + break + + if self.draft_model is not None: + self.input_ids[self.n_tokens : self.n_tokens + len(tokens)] = tokens + draft_tokens = self.draft_model(self.input_ids[:self.n_tokens + len(tokens)]) + tokens.extend( + draft_tokens.astype(int)[ + : self._n_ctx - self.n_tokens - len(tokens) + ] + ) def create_embedding( self, input: Union[str, List[str]], model: Optional[str] = None diff --git a/llama_cpp/llama_speculative.py b/llama_cpp/llama_speculative.py new file mode 100644 index 0000000..39dfb90 --- /dev/null +++ b/llama_cpp/llama_speculative.py @@ -0,0 +1,64 @@ +import abc + +from typing import Any + +import numpy as np +import numpy.typing as npt + + +class LlamaDraftModel(abc.ABC): + @abc.abstractmethod + def __call__( + self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any + ) -> npt.NDArray[np.intc]: + raise NotImplementedError() + + +class LlamaPromptLookupDecoding(LlamaDraftModel): + """Based on https://github.com/apoorvumang/prompt-lookup-decoding""" + + def __init__(self, max_ngram_size: int = 2, num_pred_tokens: int = 10): + self.max_ngram_size = max_ngram_size + self.num_pred_tokens = num_pred_tokens + + @staticmethod + def find_candidate_pred_tokens( + input_ids: npt.NDArray[np.intc], + max_ngram_size: int, + num_pred_tokens: int, + ): + input_length = input_ids.shape[0] + + for ngram_size in range(min(max_ngram_size, input_length - 1), 0, -1): + # Create sliding windows of size ngram_size + windows = np.lib.stride_tricks.sliding_window_view(input_ids, (ngram_size,)) + + # Convert ngram to an array for comparison + ngram_array = input_ids[-ngram_size:] + + # Find where the windows match the ngram + matches = np.all(windows == ngram_array, axis=1) + + # Get the indices of matches + match_indices = np.nonzero(matches)[0] + + # Iterate through match indices to find a valid continuation + for idx in match_indices: + start_idx = idx + ngram_size + end_idx = start_idx + num_pred_tokens + end_idx = min(end_idx, input_length) + + if start_idx < end_idx: + return input_ids[start_idx:end_idx] + + # If no match is found, return an empty array + return np.array([], dtype=np.intc) + + def __call__( + self, input_ids: npt.NDArray[np.intc], /, **kwargs: Any + ) -> npt.NDArray[np.intc]: + return self.find_candidate_pred_tokens( + input_ids=input_ids, + max_ngram_size=self.max_ngram_size, + num_pred_tokens=self.num_pred_tokens, + ) diff --git a/llama_cpp/server/model.py b/llama_cpp/server/model.py index bbb6806..925ab99 100644 --- a/llama_cpp/server/model.py +++ b/llama_cpp/server/model.py @@ -5,6 +5,7 @@ import json from typing import Dict, Optional, Union, List import llama_cpp +import llama_cpp.llama_speculative as llama_speculative from llama_cpp.server.settings import ModelSettings @@ -92,6 +93,12 @@ class LlamaProxy: ) ) + draft_model = None + if settings.draft_model is not None: + draft_model = llama_speculative.LlamaPromptLookupDecoding( + num_pred_tokens=settings.draft_model_num_pred_tokens + ) + kv_overrides: Optional[Dict[str, Union[bool, int, float]]] = None if settings.kv_overrides is not None: assert isinstance(settings.kv_overrides, list) @@ -147,6 +154,8 @@ class LlamaProxy: # Chat Format Params chat_format=settings.chat_format, chat_handler=chat_handler, + # Speculative Decoding + draft_model=draft_model, # Misc verbose=settings.verbose, ) diff --git a/llama_cpp/server/settings.py b/llama_cpp/server/settings.py index 9fe1a7b..60f3eec 100644 --- a/llama_cpp/server/settings.py +++ b/llama_cpp/server/settings.py @@ -143,6 +143,15 @@ class ModelSettings(BaseSettings): default=None, description="The model name or path to a pretrained HuggingFace tokenizer model. Same as you would pass to AutoTokenizer.from_pretrained().", ) + # Speculative Decoding + draft_model: Optional[str] = Field( + default=None, + description="Method to use for speculative decoding. One of (prompt-lookup-decoding).", + ) + draft_model_num_pred_tokens: int = Field( + default=10, + description="Number of tokens to predict using the draft model.", + ) # Misc verbose: bool = Field( default=True, description="Whether to print debug information." diff --git a/tests/test_llama_speculative.py b/tests/test_llama_speculative.py new file mode 100644 index 0000000..b5d4505 --- /dev/null +++ b/tests/test_llama_speculative.py @@ -0,0 +1,16 @@ +import numpy as np + +from llama_cpp.llama_speculative import LlamaPromptLookupDecoding + +def test_find_candidate_pred_tokens(): + find_candidate_pred_tokens = LlamaPromptLookupDecoding.find_candidate_pred_tokens + + # Test Case 1: Matching ngram is found + input_ids1 = np.array([1, 2, 3, 1, 2, 3, 1, 2, 3]) + result1 = find_candidate_pred_tokens(input_ids1, max_ngram_size=3, num_pred_tokens=2) + assert np.array_equal(result1, np.array([1, 2])) + + # Test Case 2: Matching ngram is not found + input_ids2 = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + result2 = find_candidate_pred_tokens(input_ids2, max_ngram_size=3, num_pred_tokens=2) + assert np.array_equal(result2, np.array([])) From a8cb34eacdab7e20553e82632c4c9bd1bdebe54b Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 31 Jan 2024 15:05:51 -0500 Subject: [PATCH 02/10] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 5cb04db..1cfb537 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 5cb04dbc16d1da38c8fdcc0111b40e67d00dd1c3 +Subproject commit 1cfb5372cf5707c8ec6dde7c874f4a44a6c4c915 From 3322eadbf30a68731f6aafe0b4d055255b46d8f7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Wed, 31 Jan 2024 15:10:18 -0500 Subject: [PATCH 03/10] Bump version --- CHANGELOG.md | 6 ++++++ llama_cpp/__init__.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 435af43..9632210 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.2.38] + +- feat: Update llama.cpp to ggerganov/llama.cpp@1cfb5372cf5707c8ec6dde7c874f4a44a6c4c915 +- feat: Add speculative decoding by @abetlen in #1120 +- fix: Pass raise_exception and add_generation_prompt to jinja2 chat template 078cca0361bf5a94d2cf52ed04980d20e32d6f95 + ## [0.2.37] - feat: Update llama.cpp to ggerganov/llama.cpp@fea4fd4ba7f6b754ac795387b275e1a014a77bde diff --git a/llama_cpp/__init__.py b/llama_cpp/__init__.py index 4ce899c..94cd401 100644 --- a/llama_cpp/__init__.py +++ b/llama_cpp/__init__.py @@ -1,4 +1,4 @@ from .llama_cpp import * from .llama import * -__version__ = "0.2.37" \ No newline at end of file +__version__ = "0.2.38" \ No newline at end of file From de526d02143b8057286e2e10de67512c2ad3480a Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Thu, 1 Feb 2024 12:35:31 -0500 Subject: [PATCH 04/10] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 1cfb537..8ca511c 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 1cfb5372cf5707c8ec6dde7c874f4a44a6c4c915 +Subproject commit 8ca511cadee2c67f0bd8c7034a2513778ee9a1b7 From 8a5911bd5d3eaab67217cad6a6dce7a041f6137b Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 2 Feb 2024 09:41:27 -0500 Subject: [PATCH 05/10] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 8ca511c..1912211 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 8ca511cadee2c67f0bd8c7034a2513778ee9a1b7 +Subproject commit 191221178f51b6e81122c5bda0fd79620e547d07 From bebfba0f08b3198cee2f1652393f032af4bf2016 Mon Sep 17 00:00:00 2001 From: Dulsara Date: Fri, 2 Feb 2024 22:35:46 +0530 Subject: [PATCH 06/10] Fix: fileno error google colab (#729) (#1156) Instead of using a devnull just made a dummy class with a 'write()' method that does nothing. --- llama_cpp/_utils.py | 38 ++++++++------------------------------ 1 file changed, 8 insertions(+), 30 deletions(-) diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py index 4a10647..4990c11 100644 --- a/llama_cpp/_utils.py +++ b/llama_cpp/_utils.py @@ -4,9 +4,9 @@ import sys import sys from typing import Any, Dict -# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor -outnull_file = open(os.devnull, "w") -errnull_file = open(os.devnull, "w") +class NullDevice(): + def write(self, s): + pass class suppress_stdout_stderr(object): # NOTE: these must be "saved" here to avoid exceptions when using @@ -21,41 +21,19 @@ class suppress_stdout_stderr(object): def __enter__(self): if self.disable: return self - - # Check if sys.stdout and sys.stderr have fileno method - if not hasattr(self.sys.stdout, 'fileno') or not hasattr(self.sys.stderr, 'fileno'): - return self # Return the instance without making changes - - self.old_stdout_fileno_undup = self.sys.stdout.fileno() - self.old_stderr_fileno_undup = self.sys.stderr.fileno() - - self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup) - self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup) - self.old_stdout = self.sys.stdout self.old_stderr = self.sys.stderr - self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup) - self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup) - - self.sys.stdout = outnull_file - self.sys.stderr = errnull_file + self.sys.stdout = NullDevice() + self.sys.stderr = NullDevice() return self def __exit__(self, *_): if self.disable: return - - # Check if sys.stdout and sys.stderr have fileno method - if hasattr(self.sys.stdout, 'fileno') and hasattr(self.sys.stderr, 'fileno'): - self.sys.stdout = self.old_stdout - self.sys.stderr = self.old_stderr - - self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) - self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) - - self.os.close(self.old_stdout_fileno) - self.os.close(self.old_stderr_fileno) + + self.sys.stdout = self.old_stdout + self.sys.stderr = self.old_stderr class MetaSingleton(type): From 7467f129e5cfb61115184402af6645fd9cb4401b Mon Sep 17 00:00:00 2001 From: Andrei Date: Fri, 2 Feb 2024 12:18:55 -0500 Subject: [PATCH 07/10] Revert "Fix: fileno error google colab (#729) (#1156)" (#1157) This reverts commit bebfba0f08b3198cee2f1652393f032af4bf2016. --- llama_cpp/_utils.py | 38 ++++++++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py index 4990c11..4a10647 100644 --- a/llama_cpp/_utils.py +++ b/llama_cpp/_utils.py @@ -4,9 +4,9 @@ import sys import sys from typing import Any, Dict -class NullDevice(): - def write(self, s): - pass +# Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor +outnull_file = open(os.devnull, "w") +errnull_file = open(os.devnull, "w") class suppress_stdout_stderr(object): # NOTE: these must be "saved" here to avoid exceptions when using @@ -21,19 +21,41 @@ class suppress_stdout_stderr(object): def __enter__(self): if self.disable: return self + + # Check if sys.stdout and sys.stderr have fileno method + if not hasattr(self.sys.stdout, 'fileno') or not hasattr(self.sys.stderr, 'fileno'): + return self # Return the instance without making changes + + self.old_stdout_fileno_undup = self.sys.stdout.fileno() + self.old_stderr_fileno_undup = self.sys.stderr.fileno() + + self.old_stdout_fileno = self.os.dup(self.old_stdout_fileno_undup) + self.old_stderr_fileno = self.os.dup(self.old_stderr_fileno_undup) + self.old_stdout = self.sys.stdout self.old_stderr = self.sys.stderr - self.sys.stdout = NullDevice() - self.sys.stderr = NullDevice() + self.os.dup2(outnull_file.fileno(), self.old_stdout_fileno_undup) + self.os.dup2(errnull_file.fileno(), self.old_stderr_fileno_undup) + + self.sys.stdout = outnull_file + self.sys.stderr = errnull_file return self def __exit__(self, *_): if self.disable: return - - self.sys.stdout = self.old_stdout - self.sys.stderr = self.old_stderr + + # Check if sys.stdout and sys.stderr have fileno method + if hasattr(self.sys.stdout, 'fileno') and hasattr(self.sys.stderr, 'fileno'): + self.sys.stdout = self.old_stdout + self.sys.stderr = self.old_stderr + + self.os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + self.os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + self.os.close(self.old_stdout_fileno) + self.os.close(self.old_stderr_fileno) class MetaSingleton(type): From 3553b146701b88f12d96575f55a8e6ef67b9c1a6 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 5 Feb 2024 13:26:50 -0500 Subject: [PATCH 08/10] Update llama.cpp --- llama_cpp/llama_cpp.py | 4 ++-- vendor/llama.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 431a99f..da2a7f3 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -445,7 +445,7 @@ class llama_model_params(Structure): # uint32_t n_batch; // prompt processing maximum batch size # uint32_t n_threads; // number of threads to use for generation # uint32_t n_threads_batch; // number of threads to use for batch processing -# int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` +# int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` # // ref: https://github.com/ggerganov/llama.cpp/pull/2054 # float rope_freq_base; // RoPE base frequency, 0 = from model @@ -502,7 +502,7 @@ class llama_context_params(Structure): ("n_batch", c_uint32), ("n_threads", c_uint32), ("n_threads_batch", c_uint32), - ("rope_scaling_type", c_int8), + ("rope_scaling_type", c_int32), ("rope_freq_base", c_float), ("rope_freq_scale", c_float), ("yarn_ext_factor", c_float), diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 1912211..78b00dd 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 191221178f51b6e81122c5bda0fd79620e547d07 +Subproject commit 78b00dda6c0d62c34f5371d47718defff6ed2b22 From 59760c85eddc72dfcc1839f43760ef72c23d6874 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 5 Feb 2024 21:52:12 -0500 Subject: [PATCH 09/10] fix: Use llama_log_callback to avoid suppress_stdout_stderr --- llama_cpp/_internals.py | 51 ++++++++++++++++------------------------- llama_cpp/_logger.py | 37 ++++++++++++++++++++++++++++++ llama_cpp/llama.py | 7 +++--- 3 files changed, 61 insertions(+), 34 deletions(-) create mode 100644 llama_cpp/_logger.py diff --git a/llama_cpp/_internals.py b/llama_cpp/_internals.py index 651cd4c..3a71ef0 100644 --- a/llama_cpp/_internals.py +++ b/llama_cpp/_internals.py @@ -18,8 +18,6 @@ from .llama_grammar import LlamaGrammar import llama_cpp.llama_cpp as llama_cpp -from ._utils import suppress_stdout_stderr - # Python wrappers over llama.h structs @@ -30,7 +28,6 @@ class _LlamaModel: _llama_free_model = None # NOTE: this must be "saved" here to avoid exceptions when calling __del__ - _suppress_stdout_stderr = suppress_stdout_stderr def __init__( self, @@ -48,16 +45,14 @@ class _LlamaModel: if not os.path.exists(path_model): raise ValueError(f"Model path does not exist: {path_model}") - with self._suppress_stdout_stderr(disable=self.verbose): - self.model = llama_cpp.llama_load_model_from_file( - self.path_model.encode("utf-8"), self.params - ) + self.model = llama_cpp.llama_load_model_from_file( + self.path_model.encode("utf-8"), self.params + ) def __del__(self): - with self._suppress_stdout_stderr(disable=self.verbose): - if self.model is not None and self._llama_free_model is not None: - self._llama_free_model(self.model) - self.model = None + if self.model is not None and self._llama_free_model is not None: + self._llama_free_model(self.model) + self.model = None def vocab_type(self) -> int: assert self.model is not None @@ -240,8 +235,6 @@ class _LlamaContext: NOTE: For stability it's recommended you use the Llama class instead.""" _llama_free = None - # NOTE: this must be "saved" here to avoid exceptions when calling __del__ - _suppress_stdout_stderr = suppress_stdout_stderr def __init__( self, @@ -256,16 +249,16 @@ class _LlamaContext: self._llama_free = llama_cpp._lib.llama_free # type: ignore - with self._suppress_stdout_stderr(disable=self.verbose): - self.ctx = llama_cpp.llama_new_context_with_model( - self.model.model, self.params - ) + assert self.model.model is not None + + self.ctx = llama_cpp.llama_new_context_with_model( + self.model.model, self.params + ) def __del__(self): - with self._suppress_stdout_stderr(disable=self.verbose): - if self.ctx is not None and self._llama_free is not None: - self._llama_free(self.ctx) - self.ctx = None + if self.ctx is not None and self._llama_free is not None: + self._llama_free(self.ctx) + self.ctx = None def n_ctx(self) -> int: assert self.ctx is not None @@ -493,8 +486,6 @@ class _LlamaContext: class _LlamaBatch: _llama_batch_free = None - # NOTE: this must be "saved" here to avoid exceptions when calling __del__ - _suppress_stdout_stderr = suppress_stdout_stderr def __init__( self, *, n_tokens: int, embd: int, n_seq_max: int, verbose: bool = True @@ -506,16 +497,14 @@ class _LlamaBatch: self._llama_batch_free = llama_cpp._lib.llama_batch_free # type: ignore - with self._suppress_stdout_stderr(disable=self.verbose): - self.batch = llama_cpp.llama_batch_init( - self.n_tokens, self.embd, self.n_seq_max - ) + self.batch = llama_cpp.llama_batch_init( + self.n_tokens, self.embd, self.n_seq_max + ) def __del__(self): - with self._suppress_stdout_stderr(disable=self.verbose): - if self.batch is not None and self._llama_batch_free is not None: - self._llama_batch_free(self.batch) - self.batch = None + if self.batch is not None and self._llama_batch_free is not None: + self._llama_batch_free(self.batch) + self.batch = None def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool): assert self.batch is not None diff --git a/llama_cpp/_logger.py b/llama_cpp/_logger.py new file mode 100644 index 0000000..7638170 --- /dev/null +++ b/llama_cpp/_logger.py @@ -0,0 +1,37 @@ +import sys +import ctypes +import logging + +import llama_cpp + +# enum ggml_log_level { +# GGML_LOG_LEVEL_ERROR = 2, +# GGML_LOG_LEVEL_WARN = 3, +# GGML_LOG_LEVEL_INFO = 4, +# GGML_LOG_LEVEL_DEBUG = 5 +# }; +GGML_LOG_LEVEL_TO_LOGGING_LEVEL = { + 2: logging.ERROR, + 3: logging.WARNING, + 4: logging.INFO, + 5: logging.DEBUG, +} + +logger = logging.getLogger("llama-cpp-python") + + +@llama_cpp.llama_log_callback +def llama_log_callback( + level: int, + text: bytes, + user_data: ctypes.c_void_p, +): + if logger.level <= GGML_LOG_LEVEL_TO_LOGGING_LEVEL[level]: + print(text.decode("utf-8"), end="", flush=True, file=sys.stderr) + + +llama_cpp.llama_log_set(llama_log_callback, ctypes.c_void_p(0)) + + +def set_verbose(verbose: bool): + logger.setLevel(logging.DEBUG if verbose else logging.ERROR) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index f00fd4f..85943db 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -35,7 +35,6 @@ from llama_cpp.llama_speculative import LlamaDraftModel import numpy as np import numpy.typing as npt -from ._utils import suppress_stdout_stderr from ._internals import ( _LlamaModel, # type: ignore _LlamaContext, # type: ignore @@ -44,6 +43,7 @@ from ._internals import ( _LlamaSamplingParams, # type: ignore _LlamaSamplingContext, # type: ignore ) +from ._logger import set_verbose class Llama: @@ -169,10 +169,11 @@ class Llama: """ self.verbose = verbose + set_verbose(verbose) + self.numa = numa if not Llama.__backend_initialized: - with suppress_stdout_stderr(disable=self.verbose): - llama_cpp.llama_backend_init(self.numa) + llama_cpp.llama_backend_init(self.numa) Llama.__backend_initialized = True self.model_path = model_path From 310fbf4e494e30d6a78be1b8a6b9be81f61204e7 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 5 Feb 2024 22:07:14 -0500 Subject: [PATCH 10/10] Update llama.cpp --- vendor/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vendor/llama.cpp b/vendor/llama.cpp index 78b00dd..098f6d7 160000 --- a/vendor/llama.cpp +++ b/vendor/llama.cpp @@ -1 +1 @@ -Subproject commit 78b00dda6c0d62c34f5371d47718defff6ed2b22 +Subproject commit 098f6d737b65134cf220d12b9b706e8cfc5e4610