diff --git a/CHANGELOG.md b/CHANGELOG.md index 1f6dac7..cb5f443 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added first version of the changelog - Server: Use async routes +- Use numpy for internal buffers to reduce memory usage and improve performance. ### Fixed diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index d7dc625..18372c8 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -20,6 +20,9 @@ from collections import deque, OrderedDict from . import llama_cpp from .llama_types import * +import numpy as np +import numpy.typing as npt + class LlamaCache: """Cache for a llama.cpp model.""" @@ -73,11 +76,15 @@ class LlamaState: self, eval_tokens: Deque[int], eval_logits: Deque[List[float]], + input_ids: npt.NDArray[np.intc], + scores: npt.NDArray[np.single], llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8] llama_state_size: int, ): self.eval_tokens = eval_tokens self.eval_logits = eval_logits + self.input_ids = input_ids + self.scores = scores self.llama_state = llama_state self.llama_state_size = llama_state_size @@ -207,20 +214,17 @@ class Llama: self._n_vocab = self.n_vocab() self._n_ctx = self.n_ctx() - data = (llama_cpp.llama_token_data * self._n_vocab)( - *[ - llama_cpp.llama_token_data( - id=llama_cpp.llama_token(i), - logit=llama_cpp.c_float(0.0), - p=llama_cpp.c_float(0.0), - ) - for i in range(self._n_vocab) - ] - ) size = llama_cpp.c_size_t(self._n_vocab) - sorted = False + sorted = llama_cpp.c_bool(False) + self._candidates_data = np.array( + [], + dtype=np.dtype( + [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True + ), + ) + self._candidates_data.resize(3, self._n_vocab) candidates = llama_cpp.llama_token_data_array( - data=data, + data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p), size=size, sorted=sorted, ) @@ -228,6 +232,9 @@ class Llama: self._token_nl = Llama.token_nl() self._token_eos = Llama.token_eos() + self._input_ids = np.array([], dtype=np.intc) + self._scores = np.ndarray((0, self._n_vocab), dtype=np.single) + def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]: """Tokenize a string. @@ -295,6 +302,8 @@ class Llama: """Reset the model state.""" self.eval_tokens.clear() self.eval_logits.clear() + self._input_ids = np.array([], dtype=np.intc) + self._scores = np.ndarray((0, self._n_vocab), dtype=np.single) def eval(self, tokens: Sequence[int]): """Evaluate a list of tokens. @@ -306,7 +315,7 @@ class Llama: n_ctx = self._n_ctx for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] - n_past = min(n_ctx - len(batch), len(self.eval_tokens)) + n_past = min(n_ctx - len(batch), len(self._input_ids)) n_tokens = len(batch) return_code = llama_cpp.llama_eval( ctx=self.ctx, @@ -319,6 +328,9 @@ class Llama: raise RuntimeError(f"llama_eval returned {return_code}") # Save tokens self.eval_tokens.extend(batch) + self._input_ids: npt.NDArray[np.intc] = np.concatenate( + (self._input_ids, np.array(batch, dtype=np.intc)), axis=0 + ) # Save logits rows = n_tokens if self.params.logits_all else 1 n_vocab = self._n_vocab @@ -326,6 +338,9 @@ class Llama: logits_view = llama_cpp.llama_get_logits(self.ctx) logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)] self.eval_logits.extend(logits) + self._scores: npt.NDArray[np.single] = np.concatenate( + (self._scores, np.array(logits, dtype=np.single)), axis=0 + ) def _sample( self, @@ -346,6 +361,7 @@ class Llama: ): assert self.ctx is not None assert len(self.eval_logits) > 0 + assert self._scores.shape[0] > 0 n_vocab = self._n_vocab n_ctx = self._n_ctx top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k @@ -354,18 +370,23 @@ class Llama: if last_n_tokens_size.value < 0 else last_n_tokens_size ) - logits = self.eval_logits[-1] + logits: npt.NDArray[np.single] = self._scores[-1, :] if logits_processor is not None: - logits = logits_processor(list(self.eval_tokens), logits) - self.eval_logits[-1] = logits + logits = np.array( + logits_processor(self._input_ids.tolist(), logits.tolist()), + dtype=np.single, + ) + self._scores[-1, :] = logits + self.eval_logits[-1] = logits.tolist() nl_logit = logits[self._token_nl] candidates = self._candidates - for i, logit in enumerate(logits): - candidates.data[i].id = llama_cpp.llama_token(i) - candidates.data[i].logit = llama_cpp.c_float(logit) - candidates.data[i].p = llama_cpp.c_float(0.0) + candidates_data = self._candidates_data + candidates_data["id"] = np.arange(n_vocab, dtype=np.intc) # type: ignore + candidates_data["logit"] = logits + candidates_data["p"] = np.zeros(n_vocab, dtype=np.single) + candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p) candidates.sorted = llama_cpp.c_bool(False) candidates.size = llama_cpp.c_size_t(n_vocab) llama_cpp.llama_sample_repetition_penalty( @@ -483,8 +504,8 @@ class Llama: """ assert self.ctx is not None last_n_tokens_data = [llama_cpp.llama_token(0)] * max( - 0, self.last_n_tokens_size - len(self.eval_tokens) - ) + list(self.eval_tokens)[-self.last_n_tokens_size :] + 0, self.last_n_tokens_size - len(self._input_ids) + ) + self._input_ids[-self.last_n_tokens_size :].tolist() return self._sample( last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)( *last_n_tokens_data @@ -542,9 +563,9 @@ class Llama: """ assert self.ctx is not None - if reset and len(self.eval_tokens) > 0: + if reset and len(self._input_ids) > 0: longest_prefix = 0 - for a, b in zip(self.eval_tokens, tokens[:-1]): + for a, b in zip(self._input_ids, tokens[:-1]): if a == b: longest_prefix += 1 else: @@ -554,6 +575,8 @@ class Llama: print("Llama.generate: prefix-match hit", file=sys.stderr) reset = False tokens = tokens[longest_prefix:] + self._input_ids = self._input_ids[:longest_prefix] + self._scores = self._scores[:longest_prefix, :] for _ in range(len(self.eval_tokens) - longest_prefix): self.eval_tokens.pop() try: @@ -580,7 +603,7 @@ class Llama: logits_processor=logits_processor, ) if stopping_criteria is not None and stopping_criteria( - list(self.eval_tokens), self.eval_logits[-1] + self._input_ids.tolist(), self._scores[-1, :].tolist() ): return tokens_or_none = yield token @@ -715,10 +738,10 @@ class Llama: try: cache_item = self.cache[prompt_tokens] cache_prefix_len = Llama.longest_token_prefix( - cache_item.eval_tokens, prompt_tokens + cache_item.input_ids.tolist(), prompt_tokens ) eval_prefix_len = Llama.longest_token_prefix( - self.eval_tokens, prompt_tokens + self._input_ids.tolist(), prompt_tokens ) if cache_prefix_len > eval_prefix_len: self.load_state(cache_item) @@ -807,7 +830,7 @@ class Llama: self.detokenize(completion_tokens[:returned_tokens]) ) token_offset = len(prompt_tokens) + returned_tokens - logits = self.eval_logits[token_offset - 1] + logits = self._scores[token_offset - 1, :].tolist() current_logprobs = Llama.logits_to_logprobs(logits) sorted_logprobs = list( sorted( @@ -856,7 +879,7 @@ class Llama: break if stopping_criteria is not None and stopping_criteria( - list(self.eval_tokens), self.eval_logits[-1] + self._input_ids.tolist(), self._scores[-1, :].tolist() ): text = self.detokenize(completion_tokens) finish_reason = "stop" @@ -886,7 +909,7 @@ class Llama: self.detokenize(completion_tokens[:returned_tokens]) ) token_offset = len(prompt_tokens) + returned_tokens - 1 - logits = self.eval_logits[token_offset] + logits = self._scores[token_offset, :].tolist() current_logprobs = Llama.logits_to_logprobs(logits) sorted_logprobs = list( sorted( @@ -988,8 +1011,7 @@ class Llama: for token in all_tokens ] all_logprobs = [ - Llama.logits_to_logprobs(list(map(float, row))) - for row in self.eval_logits + Llama.logits_to_logprobs(row.tolist()) for row in self._scores ][token_offset:] for token, token_str, logprobs_token in zip( all_tokens, all_token_strs, all_logprobs @@ -1373,6 +1395,8 @@ class Llama: return LlamaState( eval_tokens=self.eval_tokens.copy(), eval_logits=self.eval_logits.copy(), + scores=self._scores.copy(), + input_ids=self._input_ids.copy(), llama_state=llama_state_compact, llama_state_size=n_bytes, ) @@ -1381,6 +1405,8 @@ class Llama: assert self.ctx is not None self.eval_tokens = state.eval_tokens.copy() self.eval_logits = state.eval_logits.copy() + self._scores = state.scores.copy() + self._input_ids = state.input_ids.copy() state_size = state.llama_state_size if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size: raise RuntimeError("Failed to set llama state data") diff --git a/poetry.lock b/poetry.lock index 50ae0cb..70e4272 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand. +# This file is automatically @generated by Poetry and should not be changed by hand. [[package]] name = "anyio" @@ -800,14 +800,14 @@ mkdocs = ">=1.1" [[package]] name = "mkdocs-material" -version = "9.1.14" +version = "9.1.15" description = "Documentation that simply works" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "mkdocs_material-9.1.14-py3-none-any.whl", hash = "sha256:b56a9f955ed32d38333715cbbf68ce38f683bf38610c65094fa4ef2db9f08bcd"}, - {file = "mkdocs_material-9.1.14.tar.gz", hash = "sha256:1ae74cc5464ef2f64574d4884512efed7f4db386fb9bc6af20fd427d7a702f49"}, + {file = "mkdocs_material-9.1.15-py3-none-any.whl", hash = "sha256:b49e12869ab464558e2dd3c5792da5b748a7e0c48ee83b4d05715f98125a7a39"}, + {file = "mkdocs_material-9.1.15.tar.gz", hash = "sha256:8513ab847c9a541ed3d11a3a7eed556caf72991ee786c31c5aac6691a121088a"}, ] [package.dependencies] @@ -835,17 +835,18 @@ files = [ [[package]] name = "mkdocstrings" -version = "0.21.2" +version = "0.22.0" description = "Automatic documentation from sources, for MkDocs." category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "mkdocstrings-0.21.2-py3-none-any.whl", hash = "sha256:949ef8da92df9d692ca07be50616459a6b536083a25520fd54b00e8814ce019b"}, - {file = "mkdocstrings-0.21.2.tar.gz", hash = "sha256:304e56a2e90595708a38a13a278e538a67ad82052dd5c8b71f77a604a4f3d911"}, + {file = "mkdocstrings-0.22.0-py3-none-any.whl", hash = "sha256:2d4095d461554ff6a778fdabdca3c00c468c2f1459d469f7a7f622a2b23212ba"}, + {file = "mkdocstrings-0.22.0.tar.gz", hash = "sha256:82a33b94150ebb3d4b5c73bab4598c3e21468c79ec072eff6931c8f3bfc38256"}, ] [package.dependencies] +importlib-metadata = {version = ">=4.6", markers = "python_version < \"3.10\""} Jinja2 = ">=2.11.1" Markdown = ">=3.3" MarkupSafe = ">=1.1" @@ -1374,25 +1375,28 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "scikit-build" -version = "0.13.0" +version = "0.17.5" description = "Improved build system generator for Python C/C++/Fortran/Cython extensions" category = "dev" optional = false -python-versions = "*" +python-versions = ">=3.7" files = [ - {file = "scikit-build-0.13.0.tar.gz", hash = "sha256:a6ca1b7f1cc8a718564c19f535014f3a71f34508f72e750d4221f987eed0f06d"}, - {file = "scikit_build-0.13.0-py2.py3-none-any.whl", hash = "sha256:f903fef5cd76aa81dee040fa9cf3daaeff5c71fccfe5fc0bf6a62e54b166d492"}, + {file = "scikit_build-0.17.5-py3-none-any.whl", hash = "sha256:18861286b34fd2d685327d3bec6ebf4d33303adfaef28a08dd856710d16cf20f"}, + {file = "scikit_build-0.17.5.tar.gz", hash = "sha256:76856e7631d9e8887a7aa71913d5f184a6177246225391af96ce4801d89fa254"}, ] [package.dependencies] distro = "*" packaging = "*" -setuptools = {version = ">=28.0.0", markers = "python_version >= \"3\""} -wheel = ">=0.29.0" +setuptools = ">=42.0.0" +tomli = {version = "*", markers = "python_version < \"3.11\""} +wheel = ">=0.32.0" [package.extras] +cov = ["coverage[toml] (>=4.2)", "pytest-cov (>=2.7.1)"] docs = ["pygments", "sphinx (>=4)", "sphinx-issues", "sphinx-rtd-theme (>=1.0)", "sphinxcontrib-moderncmakedomain (>=3.19)"] -test = ["build (>=0.5)", "codecov (>=2.0.5)", "coverage (>=4.2)", "cython (>=0.25.1)", "flake8 (>=3.0.4)", "path.py (>=11.5.0)", "pathlib2", "pytest (>=4.5.0)", "pytest-cov (>=2.7.1)", "pytest-mock (>=1.10.4)", "pytest-runner (>=5.1)", "pytest-virtualenv (>=1.2.5)", "requests", "six (>=1.10.0)", "ubelt (>=0.8.2)", "virtualenv", "xdoctest (>=0.10.0)"] +doctest = ["ubelt (>=0.8.2)", "xdoctest (>=0.10.0)"] +test = ["build (>=0.7)", "cython (>=0.25.1)", "importlib-metadata", "pytest (>=6.0.0)", "pytest-mock (>=1.10.4)", "pytest-virtualenv (>=1.2.5)", "requests", "virtualenv"] [[package]] name = "secretstorage" @@ -1522,14 +1526,14 @@ urllib3 = ">=1.26.0" [[package]] name = "typing-extensions" -version = "4.5.0" +version = "4.6.2" description = "Backported and Experimental Type Hints for Python 3.7+" category = "main" optional = false python-versions = ">=3.7" files = [ - {file = "typing_extensions-4.5.0-py3-none-any.whl", hash = "sha256:fb33085c39dd998ac16d1431ebc293a8b3eedd00fd4a32de0ff79002c19511b4"}, - {file = "typing_extensions-4.5.0.tar.gz", hash = "sha256:5cb5f4a79139d699607b3ef622a1dedafa84e115ab0024e0d9c044a9479ca7cb"}, + {file = "typing_extensions-4.6.2-py3-none-any.whl", hash = "sha256:3a8b36f13dd5fdc5d1b16fe317f5668545de77fa0b8e02006381fd49d731ab98"}, + {file = "typing_extensions-4.6.2.tar.gz", hash = "sha256:06006244c70ac8ee83fa8282cb188f697b8db25bc8b4df07be1873c43897060c"}, ] [[package]] @@ -1552,14 +1556,14 @@ zstd = ["zstandard (>=0.18.0)"] [[package]] name = "uvicorn" -version = "0.21.1" +version = "0.22.0" description = "The lightning-fast ASGI server." category = "main" optional = true python-versions = ">=3.7" files = [ - {file = "uvicorn-0.21.1-py3-none-any.whl", hash = "sha256:e47cac98a6da10cd41e6fd036d472c6f58ede6c5dbee3dbee3ef7a100ed97742"}, - {file = "uvicorn-0.21.1.tar.gz", hash = "sha256:0fac9cb342ba099e0d582966005f3fdba5b0290579fed4a6266dc702ca7bb032"}, + {file = "uvicorn-0.22.0-py3-none-any.whl", hash = "sha256:e9434d3bbf05f310e762147f769c9f21235ee118ba2d2bf1155a7196448bd996"}, + {file = "uvicorn-0.22.0.tar.gz", hash = "sha256:79277ae03db57ce7d9aa0567830bbb51d7a612f54d6e1e3e92da3ef24c2c8ed8"}, ] [package.dependencies] @@ -1653,9 +1657,9 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -server = ["fastapi", "sse-starlette", "uvicorn"] +server = ["uvicorn", "fastapi", "sse-starlette"] [metadata] lock-version = "2.0" python-versions = "^3.8.1" -content-hash = "b1b158e4c9640e4dc197fe43e22c9f87e6e90945ec9b8bcba6042f81249d251e" +content-hash = "f5aacb68729427e49bb796a598890fedd8ba1950af3fd577fb85edde2c27338f" diff --git a/pyproject.toml b/pyproject.toml index aacdac0..39b731e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,8 +14,8 @@ include = [ [tool.poetry.dependencies] python = "^3.8.1" -typing-extensions = "^4.5.0" -uvicorn = { version = "^0.21.1", optional = true } +typing-extensions = "^4.6.2" +uvicorn = { version = "^0.22.0", optional = true } fastapi = { version = "^0.95.0", optional = true } sse-starlette = { version = "^1.3.3", optional = true } @@ -23,11 +23,11 @@ sse-starlette = { version = "^1.3.3", optional = true } black = "^23.3.0" twine = "^4.0.2" mkdocs = "^1.4.3" -mkdocstrings = {extras = ["python"], version = "^0.21.2"} -mkdocs-material = "^9.1.14" +mkdocstrings = {extras = ["python"], version = "^0.22.0"} +mkdocs-material = "^9.1.15" pytest = "^7.3.1" httpx = "^0.24.1" -scikit-build = "0.13" +scikit-build = "0.17.5" [tool.poetry.extras] server = ["uvicorn", "fastapi", "sse-starlette"] diff --git a/setup.py b/setup.py index 2136d8d..a1a2c5b 100644 --- a/setup.py +++ b/setup.py @@ -16,9 +16,7 @@ setup( license="MIT", package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"}, packages=["llama_cpp", "llama_cpp.server"], - install_requires=[ - "typing-extensions>=4.5.0", - ], + install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0"], extras_require={ "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"], },