Merge pull request #277 from abetlen/add-numpy-support

Use numpy for internal buffers

commit 49fe9395a1
3 changed files with 60 additions and 35 deletions
CHANGELOG.md

@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added

 - Added first version of the changelog
+- Use numpy for internal buffers to reduce memory usage and improve performance.

 ### Fixed

llama_cpp/llama.py

@@ -20,6 +20,9 @@ from collections import deque, OrderedDict
 from . import llama_cpp
 from .llama_types import *

+import numpy as np
+import numpy.typing as npt
+

 class LlamaCache:
     """Cache for a llama.cpp model."""
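A quick aside on the scalar types (my illustration, not part of the diff): np.intc and np.single are numpy's aliases for C int and C float, the types llama.cpp uses for token ids and logits, so buffers built from them can cross the ctypes boundary without conversion:

import ctypes
import numpy as np

# np.intc <-> ctypes.c_int (token ids), np.single <-> ctypes.c_float (logits)
assert np.dtype(np.intc).itemsize == ctypes.sizeof(ctypes.c_int)
assert np.dtype(np.single).itemsize == ctypes.sizeof(ctypes.c_float)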
@@ -73,11 +76,15 @@ class LlamaState:
         self,
         eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
+        self.input_ids = input_ids
+        self.scores = scores
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size

@@ -207,20 +214,17 @@ class Llama:

         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * self._n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=llama_cpp.c_float(0.0),
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(self._n_vocab)
-            ]
-        )
         size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = False
+        sorted = llama_cpp.c_bool(False)
+        self._candidates_data = np.array(
+            [],
+            dtype=np.dtype(
+                [("id", np.intc), ("logit", np.single), ("p", np.single)], align=True
+            ),
+        )
+        self._candidates_data.resize(3, self._n_vocab)
         candidates = llama_cpp.llama_token_data_array(
-            data=data,
+            data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=size,
             sorted=sorted,
         )
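This hunk is the heart of the change: the ctypes array of llama_token_data structs is replaced by a single structured numpy array whose memory layout matches the C struct, so the sampling code can hand llama.cpp a raw pointer into numpy-owned memory with no per-token copying. A self-contained sketch of the trick, with a hypothetical TokenData standing in for llama_cpp.llama_token_data:

import ctypes
import numpy as np

class TokenData(ctypes.Structure):  # stand-in for llama_cpp.llama_token_data
    _fields_ = [("id", ctypes.c_int), ("logit", ctypes.c_float), ("p", ctypes.c_float)]

n_vocab = 8
dtype = np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True)
buf = np.zeros(n_vocab, dtype=dtype)             # one contiguous allocation, reused every step
assert buf.itemsize == ctypes.sizeof(TokenData)  # align=True keeps the layouts in sync

ptr = buf.ctypes.data_as(ctypes.POINTER(TokenData))
buf["logit"][3] = 1.5
assert abs(ptr[3].logit - 1.5) < 1e-6            # the "C side" sees numpy's writes, no copy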
@@ -228,6 +232,9 @@ class Llama:
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()

+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.

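For orientation, the shape convention the two new buffers follow (values illustrative; this snippet is mine, not from the PR):

import numpy as np

n_vocab = 32000                                      # vocabulary size, illustrative
input_ids = np.array([], dtype=np.intc)              # grows to shape (n_past,)
scores = np.ndarray((0, n_vocab), dtype=np.single)   # one logit row per token: (n_past, n_vocab)
assert input_ids.shape == (0,) and scores.shape == (0, n_vocab)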
@@ -295,6 +302,8 @@ class Llama:
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)

     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.

@@ -306,7 +315,7 @@ class Llama:
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
@@ -319,6 +328,9 @@ class Llama:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
+            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
+                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
+            )
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
@@ -326,6 +338,9 @@ class Llama:
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
+            self._scores: npt.NDArray[np.single] = np.concatenate(
+                (self._scores, np.array(logits, dtype=np.single)), axis=0
+            )

     def _sample(
         self,
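The two eval() hunks above share one append pattern: each batch is concatenated onto the 1-D token buffer, and the fresh logit rows are stacked onto the 2-D score matrix. A minimal sketch under that convention:

import numpy as np

n_vocab = 4
input_ids = np.array([], dtype=np.intc)
scores = np.ndarray((0, n_vocab), dtype=np.single)

batch = [1, 2, 3]
logits = [[0.1, 0.2, 0.3, 0.4]]  # rows == 1, as when logits_all is off
input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)), axis=0)
scores = np.concatenate((scores, np.array(logits, dtype=np.single)), axis=0)
assert input_ids.shape == (3,) and scores.shape == (1, n_vocab)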
@@ -346,6 +361,7 @@ class Llama:
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -354,18 +370,23 @@ class Llama:
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
-        logits = self.eval_logits[-1]
+        logits: npt.NDArray[np.single] = self._scores[-1, :]

         if logits_processor is not None:
-            logits = logits_processor(list(self.eval_tokens), logits)
-            self.eval_logits[-1] = logits
+            logits = np.array(
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
+                dtype=np.single,
+            )
+            self._scores[-1, :] = logits
+            self.eval_logits[-1] = logits.tolist()

         nl_logit = logits[self._token_nl]
         candidates = self._candidates
-        for i, logit in enumerate(logits):
-            candidates.data[i].id = llama_cpp.llama_token(i)
-            candidates.data[i].logit = llama_cpp.c_float(logit)
-            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates_data = self._candidates_data
+        candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
+        candidates_data["logit"] = logits
+        candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
+        candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
         candidates.sorted = llama_cpp.c_bool(False)
         candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
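This is the other hot path: filling the candidates buffer turns into three whole-column writes instead of a Python loop doing n_vocab ctypes attribute assignments on every sampling step. The same field assignments in isolation:

import numpy as np

n_vocab = 5
logits = np.array([0.5, -1.0, 2.0, 0.0, 1.0], dtype=np.single)
dtype = np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)], align=True)
candidates_data = np.zeros(n_vocab, dtype=dtype)

candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)  # one vectorized write per field,
candidates_data["logit"] = logits                          # not one per token
candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
assert float(candidates_data["logit"][2]) == 2.0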
@@ -483,8 +504,8 @@ class Llama:
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
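sample() keeps the repetition-penalty window logic intact and only changes its source: left-pad with zero tokens while the context is shorter than last_n_tokens_size, then take the tail of the numpy buffer. Sketch:

import numpy as np

last_n_tokens_size = 6
input_ids = np.array([11, 12, 13], dtype=np.intc)
window = [0] * max(0, last_n_tokens_size - len(input_ids)) + input_ids[
    -last_n_tokens_size:
].tolist()
assert window == [0, 0, 0, 11, 12, 13]  # fixed-length window, zero-padded on the left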
@@ -542,9 +563,9 @@ class Llama:
         """
         assert self.ctx is not None

-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -554,6 +575,8 @@ class Llama:
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
         try:
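Prefix-match reuse in generate(): scan until the first mismatch, keep the matched prefix in both buffers, and only re-evaluate the remainder (the last prompt token is always re-evaluated, hence tokens[:-1]). A sketch of the bookkeeping:

import numpy as np

input_ids = np.array([1, 2, 3, 4], dtype=np.intc)
scores = np.zeros((4, 8), dtype=np.single)
tokens = [1, 2, 9, 9]  # new request sharing a 2-token prefix

longest_prefix = 0
for a, b in zip(input_ids, tokens[:-1]):
    if a == b:
        longest_prefix += 1
    else:
        break
tokens = tokens[longest_prefix:]        # only these get evaluated
input_ids = input_ids[:longest_prefix]  # buffers truncated to the shared prefix
scores = scores[:longest_prefix, :]
assert longest_prefix == 2 and tokens == [9, 9] and scores.shape == (2, 8)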
@@ -580,7 +603,7 @@ class Llama:
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
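The stopping_criteria callback now receives plain Python lists rebuilt from the numpy buffers. A hypothetical criterion with the expected (token ids, last logits) signature; the name and condition are illustrative, not from the PR:

from typing import List

def stop_on_token_7(input_ids: List[int], logits: List[float]) -> bool:
    # halt generation once token id 7 appears
    return len(input_ids) > 0 and input_ids[-1] == 7

assert stop_on_token_7([1, 7], [0.0, 0.1]) is True
assert stop_on_token_7([1, 2], [0.0, 0.1]) is False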
@@ -715,10 +738,10 @@ class Llama:
         try:
             cache_item = self.cache[prompt_tokens]
             cache_prefix_len = Llama.longest_token_prefix(
-                cache_item.eval_tokens, prompt_tokens
+                cache_item.input_ids.tolist(), prompt_tokens
             )
             eval_prefix_len = Llama.longest_token_prefix(
-                self.eval_tokens, prompt_tokens
+                self._input_ids.tolist(), prompt_tokens
             )
             if cache_prefix_len > eval_prefix_len:
                 self.load_state(cache_item)
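The cache lookup compares how much of the prompt the cached state covers against what is already evaluated, and reloads only when the cache wins. A sketch of a longest_token_prefix helper with the behavior these call sites assume (the real static method lives elsewhere in llama.py):

def longest_token_prefix(a, b):
    # length of the common leading run of token ids
    n = 0
    for x, y in zip(a, b):
        if x != y:
            break
        n += 1
    return n

assert longest_token_prefix([1, 2, 3], [1, 2, 9]) == 2
assert longest_token_prefix([], [1, 2]) == 0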
@@ -807,7 +830,7 @@ class Llama:
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
-                    logits = self.eval_logits[token_offset - 1]
+                    logits = self._scores[token_offset - 1, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -856,7 +879,7 @@ class Llama:
                 break

         if stopping_criteria is not None and stopping_criteria(
-            list(self.eval_tokens), self.eval_logits[-1]
+            self._input_ids.tolist(), self._scores[-1, :].tolist()
         ):
             text = self.detokenize(completion_tokens)
             finish_reason = "stop"
@@ -886,7 +909,7 @@ class Llama:
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
-                    logits = self.eval_logits[token_offset]
+                    logits = self._scores[token_offset, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -988,8 +1011,7 @@ class Llama:
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
             ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
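All of the logprob call sites above feed rows of _scores through Llama.logits_to_logprobs. Assuming that helper computes a plain log-softmax (the max-shift below is my numerical-stability choice, not necessarily the library's), a worked example:

import math

def logits_to_logprobs(row):
    m = max(row)  # shift by the max so the exponentials cannot overflow
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]

probs = [math.exp(lp) for lp in logits_to_logprobs([1.0, 2.0, 3.0])]
assert abs(sum(probs) - 1.0) < 1e-9  # logprobs exponentiate back to a distribution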
@@ -1373,6 +1395,8 @@ class Llama:
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
+            scores=self._scores.copy(),
+            input_ids=self._input_ids.copy(),
             llama_state=llama_state_compact,
             llama_state_size=n_bytes,
         )
@@ -1381,6 +1405,8 @@ class Llama:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
+        self._scores = state.scores.copy()
+        self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
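Note the .copy() discipline on both paths: a saved state owns its arrays, and loading copies them back in, so neither side can later mutate the other through a shared buffer. A minimal illustration:

import numpy as np

scores = np.ones((2, 4), dtype=np.single)
snapshot = scores.copy()   # save_state: the snapshot owns its buffer
scores[0, 0] = -1.0        # generation keeps mutating the live array
scores = snapshot.copy()   # load_state: restore without aliasing the snapshot
assert scores[0, 0] == 1.0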
setup.py

@@ -16,9 +16,7 @@ setup(
     license="MIT",
     package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
     packages=["llama_cpp", "llama_cpp.server"],
-    install_requires=[
-        "typing-extensions>=4.5.0",
-    ],
+    install_requires=["typing-extensions>=4.5.0", "numpy>=1.20.0"],
     extras_require={
         "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
     },