Add support for numpy
parent 4c1b7f7a76
commit 8eb9769f78
2 changed files with 39 additions and 22 deletions
llama_cpp/llama.py

@@ -20,6 +20,9 @@ from collections import deque, OrderedDict
 from . import llama_cpp
 from .llama_types import *
 
+import numpy as np
+import numpy.typing as npt
+
 
 class LlamaCache:
     """Cache for a llama.cpp model."""
@@ -73,11 +76,15 @@ class LlamaState:
         self,
         eval_tokens: Deque[int],
         eval_logits: Deque[List[float]],
+        input_ids: npt.NDArray[np.intc],
+        scores: npt.NDArray[np.single],
         llama_state,  # type: llama_cpp.Array[llama_cpp.c_uint8]
         llama_state_size: int,
     ):
         self.eval_tokens = eval_tokens
         self.eval_logits = eval_logits
+        self.input_ids = input_ids
+        self.scores = scores
         self.llama_state = llama_state
         self.llama_state_size = llama_state_size
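Note: for orientation, a minimal sketch (not part of the commit) of the shapes and dtypes the two new LlamaState fields are annotated with; n_tokens and n_vocab below are hypothetical sizes:

    import numpy as np
    import numpy.typing as npt

    n_tokens, n_vocab = 5, 32000  # hypothetical sizes, not from the commit

    # One token id per evaluated position, one row of logits per position:
    input_ids: npt.NDArray[np.intc] = np.zeros(n_tokens, dtype=np.intc)
    scores: npt.NDArray[np.single] = np.zeros((n_tokens, n_vocab), dtype=np.single)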
@@ -207,20 +214,14 @@ class Llama:
 
         self._n_vocab = self.n_vocab()
         self._n_ctx = self.n_ctx()
-        data = (llama_cpp.llama_token_data * self._n_vocab)(
-            *[
-                llama_cpp.llama_token_data(
-                    id=llama_cpp.llama_token(i),
-                    logit=llama_cpp.c_float(0.0),
-                    p=llama_cpp.c_float(0.0),
-                )
-                for i in range(self._n_vocab)
-            ]
-        )
         size = llama_cpp.c_size_t(self._n_vocab)
-        sorted = False
+        sorted = llama_cpp.c_bool(False)
+        self._candidates_data = np.array(
+            [], dtype=[("id", np.intc), ("logit", np.single), ("p", np.single)]
+        )
+        self._candidates_data.resize(3, self._n_vocab)
         candidates = llama_cpp.llama_token_data_array(
-            data=data,
+            data=self._candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p),
             size=size,
             sorted=sorted,
        )
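Note: a self-contained sketch of why the structured dtype above can be handed to llama.cpp directly: its memory layout matches the C struct, so ctypes.data_as reinterprets the numpy buffer as a C array with no per-element copying. The llama_token_data mirror below is an assumption for illustration; the real definition lives in llama_cpp:

    import ctypes
    import numpy as np

    # Assumed mirror of llama.cpp's llama_token_data struct (id, logit, p);
    # the real ctypes definition is in the llama_cpp bindings.
    class llama_token_data(ctypes.Structure):
        _fields_ = [
            ("id", ctypes.c_int),
            ("logit", ctypes.c_float),
            ("p", ctypes.c_float),
        ]

    dtype = np.dtype([("id", np.intc), ("logit", np.single), ("p", np.single)])
    assert dtype.itemsize == ctypes.sizeof(llama_token_data)  # layouts agree

    arr = np.zeros(4, dtype=dtype)
    arr["id"] = np.arange(4, dtype=np.intc)

    # Reinterpret the numpy buffer as a C pointer, no per-element copying:
    ptr = arr.ctypes.data_as(ctypes.POINTER(llama_token_data))
    assert ptr[2].id == 2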
@@ -228,6 +229,9 @@ class Llama:
         self._token_nl = Llama.token_nl()
         self._token_eos = Llama.token_eos()
 
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
+
     def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
         """Tokenize a string.
 
@@ -319,6 +323,9 @@ class Llama:
                 raise RuntimeError(f"llama_eval returned {return_code}")
             # Save tokens
             self.eval_tokens.extend(batch)
+            self._input_ids: npt.NDArray[np.intc] = np.concatenate(
+                (self._input_ids, np.array(batch, dtype=np.intc)), axis=0
+            )
             # Save logits
             rows = n_tokens if self.params.logits_all else 1
             n_vocab = self._n_vocab
@@ -326,6 +333,9 @@ class Llama:
             logits_view = llama_cpp.llama_get_logits(self.ctx)
             logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
             self.eval_logits.extend(logits)
+            self._scores: npt.NDArray[np.single] = np.concatenate(
+                (self._scores, np.array(logits, dtype=np.single)), axis=0
+            )
 
     def _sample(
         self,
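Note: a small sketch (assumed values, not from the commit) of the append pattern used in the two hunks above; np.concatenate allocates a new array on each call, so tokens and logit rows accumulate as copies rather than in-place appends:

    import numpy as np

    n_vocab = 4
    input_ids = np.array([], dtype=np.intc)             # empty token buffer
    scores = np.ndarray((0, n_vocab), dtype=np.single)  # empty (0, n_vocab) buffer

    batch = [1, 15043, 3186]                       # hypothetical token ids
    logits = [[0.1, 0.2, 0.3, 0.4]] * len(batch)   # hypothetical rows of logits

    # Each eval() call appends the new tokens and their logit rows:
    input_ids = np.concatenate((input_ids, np.array(batch, dtype=np.intc)), axis=0)
    scores = np.concatenate((scores, np.array(logits, dtype=np.single)), axis=0)
    assert input_ids.shape == (3,) and scores.shape == (3, n_vocab)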
@@ -354,18 +364,23 @@ class Llama:
             if last_n_tokens_size.value < 0
             else last_n_tokens_size
         )
-        logits = self.eval_logits[-1]
+        logits: npt.NDArray[np.single] = self._scores[-1, :]
 
         if logits_processor is not None:
-            logits = logits_processor(list(self.eval_tokens), logits)
-            self.eval_logits[-1] = logits
+            logits = np.array(
+                logits_processor(list(self.eval_tokens), logits.tolist()),
+                dtype=np.single,
+            )
+            self._scores[-1, :] = logits
+            self.eval_logits[-1] = logits.tolist()
 
         nl_logit = logits[self._token_nl]
         candidates = self._candidates
-        for i, logit in enumerate(logits):
-            candidates.data[i].id = llama_cpp.llama_token(i)
-            candidates.data[i].logit = llama_cpp.c_float(logit)
-            candidates.data[i].p = llama_cpp.c_float(0.0)
+        candidates_data = self._candidates_data
+        candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)  # type: ignore
+        candidates_data["logit"] = logits
+        candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)
+        candidates.data = candidates_data.ctypes.data_as(llama_cpp.llama_token_data_p)
+        candidates.sorted = llama_cpp.c_bool(False)
+        candidates.size = llama_cpp.c_size_t(n_vocab)
         llama_cpp.llama_sample_repetition_penalty(
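Note: a minimal standalone demonstration of the vectorized refill above; three column-wise writes into the structured array replace the old Python loop that assigned ctypes fields one token at a time:

    import numpy as np

    n_vocab = 8  # tiny vocabulary for illustration
    candidates_data = np.zeros(
        n_vocab, dtype=[("id", np.intc), ("logit", np.single), ("p", np.single)]
    )
    logits = np.linspace(-1.0, 1.0, n_vocab).astype(np.single)

    # Three vectorized column writes, as in the commit:
    candidates_data["id"] = np.arange(n_vocab, dtype=np.intc)
    candidates_data["logit"] = logits
    candidates_data["p"] = np.zeros(n_vocab, dtype=np.single)

    assert candidates_data["logit"][0] == np.single(-1.0)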
@@ -1371,6 +1386,8 @@ class Llama:
         return LlamaState(
             eval_tokens=self.eval_tokens.copy(),
             eval_logits=self.eval_logits.copy(),
+            scores=self._scores.copy(),
+            input_ids=self._input_ids.copy(),
             llama_state=llama_state_compact,
             llama_state_size=n_bytes,
         )
@@ -1379,6 +1396,8 @@ class Llama:
         assert self.ctx is not None
         self.eval_tokens = state.eval_tokens.copy()
         self.eval_logits = state.eval_logits.copy()
+        self._scores = state.scores.copy()
+        self._input_ids = state.input_ids.copy()
         state_size = state.llama_state_size
         if llama_cpp.llama_set_state_data(self.ctx, state.llama_state) != state_size:
             raise RuntimeError("Failed to set llama state data")
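Note: a brief sketch (assumed values) of why both hunks call .copy(): without it, the saved state and the live buffers would alias the same memory, and later evaluation would silently corrupt the snapshot:

    import numpy as np

    scores = np.ones((2, 4), dtype=np.single)
    saved = scores.copy()       # snapshot, as save_state() takes one

    scores[0, 0] = 0.0          # a later eval() mutates the live buffer
    assert saved[0, 0] == 1.0   # the snapshot is unaffected

    scores = saved.copy()       # load_state() restores from its own copy
    assert scores[0, 0] == 1.0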
setup.py

@@ -16,9 +16,7 @@ setup(
     license="MIT",
     package_dir={"llama_cpp": "llama_cpp", "llama_cpp.server": "llama_cpp/server"},
     packages=["llama_cpp", "llama_cpp.server"],
-    install_requires=[
-        "typing-extensions>=4.5.0",
-    ],
+    install_requires=["typing-extensions>=4.5.0", "numpy>=1.24.2"],
     extras_require={
         "server": ["uvicorn>=0.21.1", "fastapi>=0.95.0", "sse-starlette>=1.3.3"],
     },