Use pre-allocated buffers to store input_ids and scores
This commit is contained in:
parent
a5e059c053
commit
b95b0ffbeb
1 changed files with 44 additions and 42 deletions
|
@ -141,7 +141,7 @@ class LlamaDiskCache(BaseLlamaCache):
|
||||||
if _key is None:
|
if _key is None:
|
||||||
raise KeyError("Key not found")
|
raise KeyError("Key not found")
|
||||||
value: "LlamaState" = self.cache.pop(_key) # type: ignore
|
value: "LlamaState" = self.cache.pop(_key) # type: ignore
|
||||||
# NOTE: This puts an integer as key in cache, which breaks,
|
# NOTE: This puts an integer as key in cache, which breaks,
|
||||||
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
|
# Llama.longest_token_prefix(k, key) above since k is not a tuple of ints/tokens
|
||||||
# self.cache.push(_key, side="front") # type: ignore
|
# self.cache.push(_key, side="front") # type: ignore
|
||||||
return value
|
return value
|
||||||
|
@ -166,17 +166,15 @@ class LlamaDiskCache(BaseLlamaCache):
|
||||||
class LlamaState:
|
class LlamaState:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
eval_tokens: Deque[int],
|
|
||||||
eval_logits: Deque[List[float]],
|
|
||||||
input_ids: npt.NDArray[np.intc],
|
input_ids: npt.NDArray[np.intc],
|
||||||
scores: npt.NDArray[np.single],
|
scores: npt.NDArray[np.single],
|
||||||
|
n_tokens: int,
|
||||||
llama_state: bytes,
|
llama_state: bytes,
|
||||||
llama_state_size: int,
|
llama_state_size: int,
|
||||||
):
|
):
|
||||||
self.eval_tokens = eval_tokens
|
|
||||||
self.eval_logits = eval_logits
|
|
||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.scores = scores
|
self.scores = scores
|
||||||
|
self.n_tokens = n_tokens
|
||||||
self.llama_state = llama_state
|
self.llama_state = llama_state
|
||||||
self.llama_state_size = llama_state_size
|
self.llama_state_size = llama_state_size
|
||||||
|
|
||||||
|
@ -267,8 +265,6 @@ class Llama:
|
||||||
|
|
||||||
self.last_n_tokens_size = last_n_tokens_size
|
self.last_n_tokens_size = last_n_tokens_size
|
||||||
self.n_batch = min(n_ctx, n_batch)
|
self.n_batch = min(n_ctx, n_batch)
|
||||||
self.eval_tokens: Deque[int] = deque(maxlen=n_ctx)
|
|
||||||
self.eval_logits: Deque[List[float]] = deque(maxlen=n_ctx if logits_all else 1)
|
|
||||||
|
|
||||||
self.cache: Optional[BaseLlamaCache] = None
|
self.cache: Optional[BaseLlamaCache] = None
|
||||||
|
|
||||||
|
@ -329,8 +325,30 @@ class Llama:
|
||||||
self._token_nl = Llama.token_nl()
|
self._token_nl = Llama.token_nl()
|
||||||
self._token_eos = Llama.token_eos()
|
self._token_eos = Llama.token_eos()
|
||||||
|
|
||||||
self._input_ids = np.array([], dtype=np.intc)
|
self.n_tokens = 0
|
||||||
self._scores: npt.NDArray[np.single] = np.ndarray((0, self._n_vocab), dtype=np.single)
|
self.input_ids: npt.NDArray[np.intc] = np.ndarray((n_ctx,), dtype=np.intc)
|
||||||
|
self.scores: npt.NDArray[np.single] = np.ndarray(
|
||||||
|
(n_ctx, self._n_vocab), dtype=np.single
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _input_ids(self) -> npt.NDArray[np.intc]:
|
||||||
|
return self.input_ids[: self.n_tokens]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _scores(self) -> npt.NDArray[np.single]:
|
||||||
|
return self.scores[: self.n_tokens, :]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eval_tokens(self) -> Deque[int]:
|
||||||
|
return deque(self.input_ids[: self.n_tokens].tolist(), maxlen=self._n_ctx)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eval_logits(self) -> Deque[List[float]]:
|
||||||
|
return deque(
|
||||||
|
self.scores[: self.n_tokens, :].tolist(),
|
||||||
|
maxlen=self._n_ctx if self.params.logits_all else 1,
|
||||||
|
)
|
||||||
|
|
||||||
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
def tokenize(self, text: bytes, add_bos: bool = True) -> List[int]:
|
||||||
"""Tokenize a string.
|
"""Tokenize a string.
|
||||||
|
@ -397,10 +415,7 @@ class Llama:
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the model state."""
|
"""Reset the model state."""
|
||||||
self.eval_tokens.clear()
|
self.n_tokens = 0
|
||||||
self.eval_logits.clear()
|
|
||||||
self._input_ids = np.array([], dtype=np.intc)
|
|
||||||
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
|
|
||||||
|
|
||||||
def eval(self, tokens: Sequence[int]):
|
def eval(self, tokens: Sequence[int]):
|
||||||
"""Evaluate a list of tokens.
|
"""Evaluate a list of tokens.
|
||||||
|
@ -410,7 +425,6 @@ class Llama:
|
||||||
"""
|
"""
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
n_ctx = self._n_ctx
|
n_ctx = self._n_ctx
|
||||||
scores: List[npt.NDArray[np.single]] = []
|
|
||||||
for i in range(0, len(tokens), self.n_batch):
|
for i in range(0, len(tokens), self.n_batch):
|
||||||
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
batch = tokens[i : min(len(tokens), i + self.n_batch)]
|
||||||
n_past = min(n_ctx - len(batch), len(self._input_ids))
|
n_past = min(n_ctx - len(batch), len(self._input_ids))
|
||||||
|
@ -425,19 +439,16 @@ class Llama:
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
raise RuntimeError(f"llama_eval returned {return_code}")
|
raise RuntimeError(f"llama_eval returned {return_code}")
|
||||||
# Save tokens
|
# Save tokens
|
||||||
self.eval_tokens.extend(batch)
|
self.input_ids[self.n_tokens : self.n_tokens + n_tokens] = batch
|
||||||
self._input_ids: npt.NDArray[np.intc] = np.concatenate(
|
|
||||||
(self._input_ids, np.array(batch, dtype=np.intc)), axis=0
|
|
||||||
)
|
|
||||||
# Save logits
|
# Save logits
|
||||||
rows = n_tokens if self.params.logits_all else 1
|
rows = n_tokens if self.params.logits_all else 1
|
||||||
n_vocab = self._n_vocab
|
n_vocab = self._n_vocab
|
||||||
cols = n_vocab
|
cols = n_vocab
|
||||||
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
logits_view = llama_cpp.llama_get_logits(self.ctx)
|
||||||
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
|
logits = [logits_view[i * cols : (i + 1) * cols] for i in range(rows)]
|
||||||
self.eval_logits.extend(logits)
|
self.scores[self.n_tokens : self.n_tokens + n_tokens, :] = logits
|
||||||
scores.append(np.array(logits, dtype=np.single))
|
# Update n_tokens
|
||||||
self._scores = np.concatenate(scores)
|
self.n_tokens += n_tokens
|
||||||
|
|
||||||
def _sample(
|
def _sample(
|
||||||
self,
|
self,
|
||||||
|
@ -457,8 +468,7 @@ class Llama:
|
||||||
logits_processor: Optional[LogitsProcessorList] = None,
|
logits_processor: Optional[LogitsProcessorList] = None,
|
||||||
):
|
):
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
assert len(self.eval_logits) > 0
|
assert self.n_tokens > 0
|
||||||
assert self._scores.shape[0] > 0
|
|
||||||
n_vocab = self._n_vocab
|
n_vocab = self._n_vocab
|
||||||
n_ctx = self._n_ctx
|
n_ctx = self._n_ctx
|
||||||
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
|
||||||
|
@ -475,7 +485,6 @@ class Llama:
|
||||||
dtype=np.single,
|
dtype=np.single,
|
||||||
)
|
)
|
||||||
self._scores[-1, :] = logits
|
self._scores[-1, :] = logits
|
||||||
self.eval_logits[-1] = logits.tolist()
|
|
||||||
|
|
||||||
nl_logit = logits[self._token_nl]
|
nl_logit = logits[self._token_nl]
|
||||||
candidates = self._candidates
|
candidates = self._candidates
|
||||||
|
@ -672,14 +681,7 @@ class Llama:
|
||||||
print("Llama.generate: prefix-match hit", file=sys.stderr)
|
print("Llama.generate: prefix-match hit", file=sys.stderr)
|
||||||
reset = False
|
reset = False
|
||||||
tokens = tokens[longest_prefix:]
|
tokens = tokens[longest_prefix:]
|
||||||
self._input_ids = self._input_ids[:longest_prefix]
|
self.n_tokens = longest_prefix
|
||||||
self._scores = self._scores[:longest_prefix, :]
|
|
||||||
for _ in range(len(self.eval_tokens) - longest_prefix):
|
|
||||||
self.eval_tokens.pop()
|
|
||||||
try:
|
|
||||||
self.eval_logits.pop()
|
|
||||||
except IndexError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if reset:
|
if reset:
|
||||||
self.reset()
|
self.reset()
|
||||||
|
@ -819,7 +821,9 @@ class Llama:
|
||||||
llama_cpp.llama_reset_timings(self.ctx)
|
llama_cpp.llama_reset_timings(self.ctx)
|
||||||
|
|
||||||
if len(prompt_tokens) > self._n_ctx:
|
if len(prompt_tokens) > self._n_ctx:
|
||||||
raise ValueError(f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}")
|
raise ValueError(
|
||||||
|
f"Requested tokens ({len(prompt_tokens)}) exceed context window of {self._n_ctx}"
|
||||||
|
)
|
||||||
|
|
||||||
# Truncate max_tokens if requested tokens would exceed the context window
|
# Truncate max_tokens if requested tokens would exceed the context window
|
||||||
max_tokens = (
|
max_tokens = (
|
||||||
|
@ -1513,22 +1517,20 @@ class Llama:
|
||||||
file=sys.stderr,
|
file=sys.stderr,
|
||||||
)
|
)
|
||||||
return LlamaState(
|
return LlamaState(
|
||||||
eval_tokens=self.eval_tokens.copy(),
|
scores=self.scores.copy(),
|
||||||
eval_logits=self.eval_logits.copy(),
|
input_ids=self.input_ids.copy(),
|
||||||
scores=self._scores.copy(),
|
n_tokens=self.n_tokens,
|
||||||
input_ids=self._input_ids.copy(),
|
|
||||||
llama_state=bytes(llama_state_compact),
|
llama_state=bytes(llama_state_compact),
|
||||||
llama_state_size=n_bytes,
|
llama_state_size=n_bytes,
|
||||||
)
|
)
|
||||||
|
|
||||||
def load_state(self, state: LlamaState) -> None:
|
def load_state(self, state: LlamaState) -> None:
|
||||||
assert self.ctx is not None
|
assert self.ctx is not None
|
||||||
self.eval_tokens = state.eval_tokens.copy()
|
self.scores = state.scores.copy()
|
||||||
self.eval_logits = state.eval_logits.copy()
|
self.input_ids = state.input_ids.copy()
|
||||||
self._scores = state.scores.copy()
|
self.n_tokens = state.n_tokens
|
||||||
self._input_ids = state.input_ids.copy()
|
|
||||||
state_size = state.llama_state_size
|
state_size = state.llama_state_size
|
||||||
LLamaStateArrayType = (llama_cpp.c_uint8 * state_size)
|
LLamaStateArrayType = llama_cpp.c_uint8 * state_size
|
||||||
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
|
llama_state = LLamaStateArrayType.from_buffer_copy(state.llama_state)
|
||||||
|
|
||||||
if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
|
if llama_cpp.llama_set_state_data(self.ctx, llama_state) != state_size:
|
||||||
|
|
Loading…
Add table
Reference in a new issue