Replace eval_logits and eval_tokens with numpy arrays

This commit is contained in:
Andrei Betlen 2023-05-26 20:03:31 -04:00
parent efb763bcdc
commit fe331ec589

View file

@ -299,6 +299,8 @@ class Llama:
"""Reset the model state."""
self.eval_tokens.clear()
self.eval_logits.clear()
self._input_ids = np.array([], dtype=np.intc)
self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
def eval(self, tokens: Sequence[int]):
"""Evaluate a list of tokens.
@ -310,7 +312,7 @@ class Llama:
n_ctx = self._n_ctx
for i in range(0, len(tokens), self.n_batch):
batch = tokens[i : min(len(tokens), i + self.n_batch)]
n_past = min(n_ctx - len(batch), len(self.eval_tokens))
n_past = min(n_ctx - len(batch), len(self._input_ids))
n_tokens = len(batch)
return_code = llama_cpp.llama_eval(
ctx=self.ctx,
@ -356,6 +358,7 @@ class Llama:
):
assert self.ctx is not None
assert len(self.eval_logits) > 0
assert self._scores.shape[0] > 0
n_vocab = self._n_vocab
n_ctx = self._n_ctx
top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@ -368,7 +371,7 @@ class Llama:
if logits_processor is not None:
logits = np.array(
logits_processor(list(self.eval_tokens), logits.tolist()),
logits_processor(self._input_ids.tolist(), logits.tolist()),
dtype=np.single,
)
self._scores[-1, :] = logits
@ -498,8 +501,8 @@ class Llama:
"""
assert self.ctx is not None
last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
0, self.last_n_tokens_size - len(self.eval_tokens)
) + list(self.eval_tokens)[-self.last_n_tokens_size :]
0, self.last_n_tokens_size - len(self._input_ids)
) + self._input_ids[-self.last_n_tokens_size :].tolist()
return self._sample(
last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
*last_n_tokens_data
@ -557,9 +560,9 @@ class Llama:
"""
assert self.ctx is not None
if reset and len(self.eval_tokens) > 0:
if reset and len(self._input_ids) > 0:
longest_prefix = 0
for a, b in zip(self.eval_tokens, tokens[:-1]):
for a, b in zip(self._input_ids, tokens[:-1]):
if a == b:
longest_prefix += 1
else:
@ -569,6 +572,8 @@ class Llama:
print("Llama.generate: prefix-match hit", file=sys.stderr)
reset = False
tokens = tokens[longest_prefix:]
self._input_ids = self._input_ids[:longest_prefix]
self._scores = self._scores[:longest_prefix, :]
for _ in range(len(self.eval_tokens) - longest_prefix):
self.eval_tokens.pop()
try:
@ -595,7 +600,7 @@ class Llama:
logits_processor=logits_processor,
)
if stopping_criteria is not None and stopping_criteria(
list(self.eval_tokens), self.eval_logits[-1]
self._input_ids.tolist(), self._scores[-1, :].tolist()
):
return
tokens_or_none = yield token
@ -820,7 +825,7 @@ class Llama:
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens
logits = self.eval_logits[token_offset - 1]
logits = self._scores[token_offset - 1, :].tolist()
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
@ -869,7 +874,7 @@ class Llama:
break
if stopping_criteria is not None and stopping_criteria(
list(self.eval_tokens), self.eval_logits[-1]
self._input_ids.tolist(), self._scores[-1, :].tolist()
):
text = self.detokenize(completion_tokens)
finish_reason = "stop"
@ -899,7 +904,7 @@ class Llama:
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens - 1
logits = self.eval_logits[token_offset]
logits = self._scores[token_offset, :].tolist()
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
@ -1001,8 +1006,7 @@ class Llama:
for token in all_tokens
]
all_logprobs = [
Llama.logits_to_logprobs(list(map(float, row)))
for row in self.eval_logits
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
][token_offset:]
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs