From 8eb9769f78465ae0926d5f7d28cc368b877be96d Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 May 2023 16:12:45 -0400 Subject: [PATCH 1/7] Add support for numpy --- llama_cpp/llama.py | 57 ++++++++++++++++++++++++++++++---------------- setup.py | 4 +--- 2 files changed, 39 insertions(+), 22 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 012bb86..6babebd 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -20,6 +20,9 @@ from collections import deque, OrderedDict from . import llama_cpp from .llama_types import * +import numpy as np +import numpy.typing as npt + class LlamaCache: """Cache for a llama.cpp model.""" @@ -73,11 +76,15 @@ class LlamaState: self, eval_tokens: Deque[int], eval_logits: Deque[List[float]], + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] llama_state_size: int, ): self.eval_tokens = eval_tokens self.eval_logits = eval_logits + self.input_ids = input_ids + self.scores = scores self.llama_state = llama_state self.llama_state_size = llama_state_size @@ -207,20 +214,14 @@ class Llama: self._n_vocab = self.n_vocab() self._n_ctx = self.n_ctx() - data = (llama_cpp.llama_token_data * self._n_vocab)( - *[ - llama_cpp.llama_token_data( - id=llama_cpp.llama_token(i), - logit=llama_cpp.c_float(0.0), - p=llama_cpp.c_float(0.0), - ) - for i in range(self._n_vocab) - ] - ) size = llama_cpp.c_size_t(self._n_vocab) - sorted = False + sorted = llama_cpp.c_bool(False) + self._candidates_data = np.array( + [], dtype=[("id", np.intc), ("logit", np.single), ("p", np.single)] + ) + self._candidates_data.resize(3, self._n_vocab) candidates = llama_cpp.llama_token_data_array( - data=data, + data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), size=size, sorted=sorted, ) @@ -228,6 +229,9 @@ class Llama: self._token_nl = Llama.token_nl() self._token_eos = Llama.token_eos() + self._input_ids = np.array([], dtype=np.intc) + self._scores = np.ndarray((0, self._n_vocab), dtype=np.single) + def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: """Tokenize a string. 
@@ -319,6 +323,9 @@ class Llama: raise RuntimeError(f"llama_eval returned {return_code}") # Save tokens self.eval_tokens.extend(batch) + self._input_ids: npt.NDArray[np.intc] = np.concatenate( + (self._input_ids, np.array(batch, dtype=np.intc)), axis=0 + ) # Save logits rows = n_tokens if self.params.logits_all else 1 n_vocab = self._n_vocab @@ -326,6 +333,9 @@ class Llama: logits_view = llama_cpp.llama_get_logits(self.ctx) logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] self.eval_logits.extend(logits) + self._scores: npt.NDArray[np.single] = np.concatenate( + (self._scores, np.array(logits, dtype=np.single)), axis=0 + ) def _sample( self, @@ -354,18 +364,23 @@ class Llama: if last_n_tokens_size.value < 0 else last_n_tokens_size ) - logits = self.eval_logits[-1] + logits: npt.NDArray[np.single] = self._scores[-1, :] if logits_processor is not None: - logits = logits_processor(list(self.eval_tokens), logits) - self.eval_logits[-1] = logits + logits = np.array( + logits_processor(list(self.eval_tokens), logits.tolist()), + dtype=np.single, + ) + self._scores[-1, :] = logits + self.eval_logits[-1] = logits.tolist() nl_logit = logits[self._token_nl] candidates = self._candidates - for i, logit in enumerate(logits): - candidates.data[i].id = llama_cpp.llama_token(i) - candidates.data[i].logit = llama_cpp.c_float(logit) - candidates.data[i].p = llama_cpp.c_float(0.0) + candidates_data = self._candidates_data + candidates_data["id"] = np.arange(n_vocab, dtype=np.intc) # type: ignore + candidates_data["logit"] = logits + candidates_data["p"] = np.zeros(n_vocab, dtype=np.single) + candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p) candidates.sorted = llama_cpp.c_bool(False) candidates.size = llama_cpp.c_size_t(n_vocab) llama_cpp.llama_sample_repetition_penalty( @@ -1371,6 +1386,8 @@ class Llama: return LlamaState( eval_tokens=self.eval_tokens.copy(), eval_logits=self.eval_logits.copy(), + scores=self._scores.copy(), + input_ids=self._input_ids.copy(), llama_state=llama_state_compact, llama_state_size=n_bytes, ) @@ -1379,6 +1396,8 @@ class Llama: assert self.ctx is not None self.eval_tokens = state.eval_tokens.copy() self.eval_logits = state.eval_logits.copy() + self._scores = state.scores.copy() + self._input_ids = state.input_ids.copy() state_size = state.llama_state_size if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size: raise RuntimeError("Failed to set llama state data") diff --git a/setup.py b/setup.py index bd7192f..198dd74 100644 --- a/setup.py +++ b/setup.py @@ -16,9 +16,7 @@ setup( license="MIT", package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"}, packages=["llama_cpp", "llama_cpp.server"], - install_requires=[ - "typing-extensions>=4.5.0", - ], + install_requires=["typing-extensions>=4.5.0", "numpy>=1.24.2"], extras_require={ "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"], }, From bd4b95da45aa129277cdba0ccdab10a1af99c2e5 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 May 2023 16:38:21 -0400 Subject: [PATCH 2/7] Reduce numpy version dependency --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 198dd74..c51202e 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ setup( license="MIT", package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"}, packages=["llama_cpp", "llama_cpp.server"], - install_requires=["typing-extensions>=4.5.0", "numpy>=1.24.2"], + 
install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0"], extras_require={ "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"], }, From fe331ec58914feaacfa3052957fef53bbd005997 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 May 2023 20:03:31 -0400 Subject: [PATCH 3/7] Replace eval_logits and eval_tokens with numpy arrays --- llama_cpp/llama.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 6babebd..4f10227 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -299,6 +299,8 @@ class Llama: """Reset the model state.""" self.eval_tokens.clear() self.eval_logits.clear() + self._input_ids = np.array([], dtype=np.intc) + self._scores = np.ndarray((0, self._n_vocab), dtype=np.single) def eval(self, tokens: Sequence[int]): """Evaluate a list of tokens. @@ -310,7 +312,7 @@ class Llama: n_ctx = self._n_ctx for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] - n_past = min(n_ctx - len(batch), len(self.eval_tokens)) + n_past = min(n_ctx - len(batch), len(self._input_ids)) n_tokens = len(batch) return_code = llama_cpp.llama_eval( ctx=self.ctx, @@ -356,6 +358,7 @@ class Llama: ): assert self.ctx is not None assert len(self.eval_logits) > 0 + assert self._scores.shape[0] > 0 n_vocab = self._n_vocab n_ctx = self._n_ctx top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k @@ -368,7 +371,7 @@ class Llama: if logits_processor is not None: logits = np.array( - logits_processor(list(self.eval_tokens), logits.tolist()), + logits_processor(self._input_ids.tolist(), logits.tolist()), dtype=np.single, ) self._scores[-1, :] = logits @@ -498,8 +501,8 @@ class Llama: """ assert self.ctx is not None last_n_tokens_data = [llama_cpp.llama_token(0)] * max( - 0, self.last_n_tokens_size - len(self.eval_tokens) - ) + list(self.eval_tokens)[-self.last_n_tokens_size :] + 0, self.last_n_tokens_size - len(self._input_ids) + ) + self._input_ids[-self.last_n_tokens_size :].tolist() return self._sample( last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( *last_n_tokens_data @@ -557,9 +560,9 @@ class Llama: """ assert self.ctx is not None - if reset and len(self.eval_tokens) > 0: + if reset and len(self._input_ids) > 0: longest_prefix = 0 - for a, b in zip(self.eval_tokens, tokens[:-1]): + for a, b in zip(self._input_ids, tokens[:-1]): if a == b: longest_prefix += 1 else: @@ -569,6 +572,8 @@ class Llama: print("Llama.generate: prefix-match hit", file=sys.stderr) reset = False tokens = tokens[longest_prefix:] + self._input_ids = self._input_ids[:longest_prefix] + self._scores = self._scores[:longest_prefix, :] for _ in range(len(self.eval_tokens) - longest_prefix): self.eval_tokens.pop() try: @@ -595,7 +600,7 @@ class Llama: logits_processor=logits_processor, ) if stopping_criteria is not None and stopping_criteria( - list(self.eval_tokens), self.eval_logits[-1] + self._input_ids.tolist(), self._scores[-1, :].tolist() ): return tokens_or_none = yield token @@ -820,7 +825,7 @@ class Llama: self.detokenize(completion_tokens[:returned_tokens]) ) token_offset = len(prompt_tokens) + returned_tokens - logits = self.eval_logits[token_offset - 1] + logits = self._scores[token_offset - 1, :].tolist() current_logprobs = Llama.logits_to_logprobs(logits) sorted_logprobs = list( sorted( @@ -869,7 +874,7 @@ class Llama: break if stopping_criteria is not None and stopping_criteria( - list(self.eval_tokens), self.eval_logits[-1] + 
self._input_ids.tolist(), self._scores[-1, :].tolist() ): text = self.detokenize(completion_tokens) finish_reason = "stop" @@ -899,7 +904,7 @@ class Llama: self.detokenize(completion_tokens[:returned_tokens]) ) token_offset = len(prompt_tokens) + returned_tokens - 1 - logits = self.eval_logits[token_offset] + logits = self._scores[token_offset, :].tolist() current_logprobs = Llama.logits_to_logprobs(logits) sorted_logprobs = list( sorted( @@ -1001,8 +1006,7 @@ class Llama: for token in all_tokens ] all_logprobs = [ - Llama.logits_to_logprobs(list(map(float, row))) - for row in self.eval_logits + Llama.logits_to_logprobs(row.tolist()) for row in self._scores ][token_offset:] for token, token_str, logprobs_token in zip( all_tokens, all_token_strs, all_logprobs From 7fc7bc30e712c10d633a7acf912134ae92c0fbe3 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 May 2023 20:12:05 -0400 Subject: [PATCH 4/7] Remove usage of eval_tokens for cache check --- llama_cpp/llama.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 4f10227..064b982 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -735,10 +735,10 @@ class Llama: try: cache_item = self.cache[prompt_tokens] cache_prefix_len = Llama.longest_token_prefix( - cache_item.eval_tokens, prompt_tokens + cache_item.input_ids.tolist(), prompt_tokens ) eval_prefix_len = Llama.longest_token_prefix( - self.eval_tokens, prompt_tokens + self._input_ids.tolist(), prompt_tokens ) if cache_prefix_len > eval_prefix_len: self.load_state(cache_item) From b0b154cfa6d22d317ad26f974f9916f79bbc78c2 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 May 2023 20:26:08 -0400 Subject: [PATCH 5/7] Add changelog message for numpy --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b5fbec..ccb1c7e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added first version of the changelog +- Use numpy for internal buffers to reduce memory usage and improve performance. 
### Fixed From 84e313bd6e18e341f35be6c87e7151e7ce8d926d Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 May 2023 22:02:16 -0400 Subject: [PATCH 6/7] Align dtype to match c structs --- llama_cpp/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 3084b33..ac51ce5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -217,7 +217,7 @@ class Llama: size = llama_cpp.c_size_t(self._n_vocab) sorted = llama_cpp.c_bool(False) self._candidates_data = np.array( - [], dtype=[("id", np.intc), ("logit", np.single), ("p", np.single)] + [], dtype=np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True) ) self._candidates_data.resize(3, self._n_vocab) candidates = llama_cpp.llama_token_data_array( From 8f2b4456ad5b7a80be9264fa94927e8a79ed16a9 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Fri, 26 May 2023 22:04:31 -0400 Subject: [PATCH 7/7] Format --- llama_cpp/llama.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index ac51ce5..18372c8 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -217,7 +217,10 @@ class Llama: size = llama_cpp.c_size_t(self._n_vocab) sorted = llama_cpp.c_bool(False) self._candidates_data = np.array( - [], dtype=np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True) + [], + dtype=np.dtype( + [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True + ), ) self._candidates_data.resize(3, self._n_vocab) candidates = llama_cpp.llama_token_data_array(
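
Editor's note (not part of the patch series): PATCH 6/7 switches the candidates buffer to an aligned structured dtype because the array's raw pointer is handed to llama.cpp as a llama_token_data*, so numpy's field layout has to match what a C compiler produces for that struct. The sketch below illustrates the idea under stated assumptions: TokenData is a stand-in ctypes mirror of the struct written only for this example (it is not the real llama_cpp binding), and n_vocab is a placeholder size. For this particular struct every field is 4 bytes, so the packed and aligned layouts happen to coincide; align=True simply keeps the buffer compatible should the C struct ever gain fields with different alignment.

import ctypes
import numpy as np

class TokenData(ctypes.Structure):
    # Stand-in mirror of: struct llama_token_data { llama_token id; float logit; float p; };
    _fields_ = [
        ("id", ctypes.c_int),
        ("logit", ctypes.c_float),
        ("p", ctypes.c_float),
    ]

packed = np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)])
aligned = np.dtype(
    [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
)
assert aligned.itemsize == ctypes.sizeof(TokenData)  # layouts agree

n_vocab = 8  # placeholder; the real code uses the model's vocabulary size
candidates_data = np.empty(n_vocab, dtype=aligned)

# Vectorised fill of every candidate entry, the same shape of update that
# _sample() performs per step (the patch assigns the real logits here).
candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
candidates_data["logit"] = np.zeros(n_vocab, dtype=np.single)
candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)

# The buffer can now be passed to C code expecting a llama_token_data pointer.
ptr = candidates_data.ctypes.data_as(ctypes.POINTER(TokenData))
print(packed.itemsize, aligned.itemsize, ptr[3].id)  # 12 12 3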
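
A second editor's note: PATCH 3/7 keeps the original Python zip loop for the prefix-match check in generate(), now iterating over self._input_ids. Since the token history is a numpy array after this series, the same longest-common-prefix length could also be computed without an explicit loop. The helper below is a hedged sketch of that alternative, written for illustration rather than taken from the patches.

import numpy as np

def longest_common_prefix(history: np.ndarray, tokens: np.ndarray) -> int:
    """Length of the shared prefix of two 1-D token arrays."""
    n = min(len(history), len(tokens))
    if n == 0:
        return 0
    mismatch = history[:n] != tokens[:n]
    # argmax returns the index of the first True; if nothing mismatches,
    # the whole compared span is a common prefix.
    return int(mismatch.argmax()) if mismatch.any() else n

# Example: a cached history of 5 tokens vs. a new prompt sharing the first 3.
print(longest_common_prefix(np.array([1, 2, 3, 4, 5], dtype=np.intc),
                            np.array([1, 2, 3, 9], dtype=np.intc)))  # 3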