Merge pull request #277 from abetlen/add-numpy-support

Use numpy for internal buffers
This commit is contained in:
Andrei 2023-05-29 20:59:30 -04:00 committed by GitHub
commit 49fe9395a1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 60 additions and 35 deletions

View file

@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- Added first version of the changelog
- Use numpy for internal buffers to reduce memory usage and improve performance.
### Fixed

View file

@ -20,6 +20,9 @@ from collections import deque, OrderedDict
from . import llama_cpp
from .llama_types import *
import numpy as np
import numpy.typing as npt
class LlamaCache:
"""Cache for a llama.cpp model."""
@ -73,11 +76,15 @@ class LlamaState:
self,
eval_tokens: Deque[int],
eval_logits: Deque[List[float]],
input_ids: npt.NDArray[np.intc],
scores: npt.NDArray[np.single],
llama_state, # type: llama_cpp.Array[llama_cpp.c_uint8]
llama_state_size: int,
):
self.eval_tokens = eval_tokens
self.eval_logits = eval_logits
self.input_ids = input_ids
self.scores = scores
self.llama_state = llama_state
self.llama_state_size = llama_state_size
@ -207,20 +214,17 @@ class Llama:
self._n_vocab = self.n_vocab()
self._n_ctx = self.n_ctx()
data = (llama_cpp.llama_token_data * self._n_vocab)(
*[
llama_cpp.llama_token_data(
id=llama_cpp.llama_token(i),
logit=llama_cpp.c_float(0.0),
p=llama_cpp.c_float(0.0),
)
for i in range(self._n_vocab)
]
)
size = llama_cpp.c_size_t(self._n_vocab)
sorted = False
sorted = llama_cpp.c_bool(False)
self._candidates_data = np.array(
[],
dtype=np.dtype(
[("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
),
)
self._candidates_data.resize(3, self._n_vocab)
candidates = llama_cpp.llama_token_data_array(
data=data,
data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
size=size,
sorted=sorted,
)
@ -228,6 +232,9 @@ class Llama:
self._token_nl = Llama.token_nl()
self._token_eos = Llama.token_eos()
self._input_ids = np.array([], dtype=np.intc)
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
"""Tokenize a string.
@ -295,6 +302,8 @@ class Llama:
"""Reset the model state."""
self.eval_tokens.clear()
self.eval_logits.clear()
self._input_ids = np.array([], dtype=np.intc)
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
def eval(self, tokens: Sequence[int]):
"""Evaluate a list of tokens.
@ -306,7 +315,7 @@ class Llama:
n_ctx = self._n_ctx
for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)]
n_past = min(n_ctx - len(batch), len(self.eval_tokens))
n_past = min(n_ctx - len(batch), len(self._input_ids))
n_tokens = len(batch)
return_code = llama_cpp.llama_eval(
ctx=self.ctx,
@ -319,6 +328,9 @@ class Llama:
raise RuntimeError(f"llama_eval returned {return_code}")
# Save tokens
self.eval_tokens.extend(batch)
self._input_ids: npt.NDArray[np.intc] = np.concatenate(
(self._input_ids, np.array(batch, dtype=np.intc)), axis=0
)
# Save logits
rows = n_tokens if self.params.logits_all else 1
n_vocab = self._n_vocab
@ -326,6 +338,9 @@ class Llama:
logits_view = llama_cpp.llama_get_logits(self.ctx)
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
self.eval_logits.extend(logits)
self._scores: npt.NDArray[np.single] = np.concatenate(
(self._scores, np.array(logits, dtype=np.single)), axis=0
)
def _sample(
self,
@ -346,6 +361,7 @@ class Llama:
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
assert self._scores.shape[0] > 0
n_vocab = self._n_vocab
n_ctx = self._n_ctx
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@ -354,18 +370,23 @@ class Llama:
if last_n_tokens_size.value < 0
else last_n_tokens_size
)
logits = self.eval_logits[-1]
logits: npt.NDArray[np.single] = self._scores[-1, :]
if logits_processor is not None:
logits = logits_processor(list(self.eval_tokens), logits)
self.eval_logits[-1] = logits
logits = np.array(
logits_processor(self._input_ids.tolist(), logits.tolist()),
dtype=np.single,
)
self._scores[-1, :] = logits
self.eval_logits[-1] = logits.tolist()
nl_logit = logits[self._token_nl]
candidates = self._candidates
for i, logit in enumerate(logits):
candidates.data[i].id = llama_cpp.llama_token(i)
candidates.data[i].logit = llama_cpp.c_float(logit)
candidates.data[i].p = llama_cpp.c_float(0.0)
candidates_data = self._candidates_data
candidates_data["id"] = np.arange(n_vocab, dtype=np.intc) # type: ignore
candidates_data["logit"] = logits
candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
candidates.sorted = llama_cpp.c_bool(False)
candidates.size = llama_cpp.c_size_t(n_vocab)
llama_cpp.llama_sample_repetition_penalty(
@ -483,8 +504,8 @@ class Llama:
"""
assert self.ctx is not None
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
0, self.last_n_tokens_size - len(self._input_ids)
) + self._input_ids[-self.last_n_tokens_size :].tolist()
return self._sample(
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
*last_n_tokens_data
@ -542,9 +563,9 @@ class Llama:
"""
assert self.ctx is not None
if reset and len(self.eval_tokens) > 0:
if reset and len(self._input_ids) > 0:
longest_prefix = 0
for a, b in zip(self.eval_tokens, tokens[:-1]):
for a, b in zip(self._input_ids, tokens[:-1]):
if a == b:
longest_prefix += 1
else:
@ -554,6 +575,8 @@ class Llama:
print("Llama.generate: prefix-match hit", file=sys.stderr)
reset = False
tokens = tokens[longest_prefix:]
self._input_ids = self._input_ids[:longest_prefix]
self._scores = self._scores[:longest_prefix, :]
for _ in range(len(self.eval_tokens) - longest_prefix):
self.eval_tokens.pop()
try:
@ -580,7 +603,7 @@ class Llama:
logits_processor=logits_processor,
)
if stopping_criteria is not None and stopping_criteria(
list(self.eval_tokens), self.eval_logits[-1]
self._input_ids.tolist(), self._scores[-1, :].tolist()
):
return
tokens_or_none = yield token
@ -715,10 +738,10 @@ class Llama:
try:
cache_item = self.cache[prompt_tokens]
cache_prefix_len = Llama.longest_token_prefix(
cache_item.eval_tokens, prompt_tokens
cache_item.input_ids.tolist(), prompt_tokens
)
eval_prefix_len = Llama.longest_token_prefix(
self.eval_tokens, prompt_tokens
self._input_ids.tolist(), prompt_tokens
)
if cache_prefix_len > eval_prefix_len:
self.load_state(cache_item)
@ -807,7 +830,7 @@ class Llama:
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens
logits = self.eval_logits[token_offset - 1]
logits = self._scores[token_offset - 1, :].tolist()
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
@ -856,7 +879,7 @@ class Llama:
break
if stopping_criteria is not None and stopping_criteria(
list(self.eval_tokens), self.eval_logits[-1]
self._input_ids.tolist(), self._scores[-1, :].tolist()
):
text = self.detokenize(completion_tokens)
finish_reason = "stop"
@ -886,7 +909,7 @@ class Llama:
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens - 1
logits = self.eval_logits[token_offset]
logits = self._scores[token_offset, :].tolist()
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
@ -988,8 +1011,7 @@ class Llama:
for token in all_tokens
]
all_logprobs = [
Llama.logits_to_logprobs(list(map(float, row)))
for row in self.eval_logits
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
][token_offset:]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
@ -1373,6 +1395,8 @@ class Llama:
return LlamaState(
eval_tokens=self.eval_tokens.copy(),
eval_logits=self.eval_logits.copy(),
scores=self._scores.copy(),
input_ids=self._input_ids.copy(),
llama_state=llama_state_compact,
llama_state_size=n_bytes,
)
@ -1381,6 +1405,8 @@ class Llama:
assert self.ctx is not None
self.eval_tokens = state.eval_tokens.copy()
self.eval_logits = state.eval_logits.copy()
self._scores = state.scores.copy()
self._input_ids = state.input_ids.copy()
state_size = state.llama_state_size
if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
raise RuntimeError("Failed to set llama state data")

View file

@ -16,9 +16,7 @@ setup(
license="MIT",
package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
packages=["llama_cpp", "llama_cpp.server"],
install_requires=[
"typing-extensions>=4.5.0",
],
install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0"],
extras_require={
"server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
},