Replace eval_logits and eval_tokens with numpy arrays
parent efb763bcdc
commit fe331ec589
1 changed file with 16 additions and 12 deletions (llama_cpp/llama.py)
@@ -299,6 +299,8 @@ class Llama:
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
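Note: the commit keeps the old eval_tokens/eval_logits deques for now and adds numpy mirrors alongside them: _input_ids as a 1-D array of token ids and _scores as a 2-D array with one row of n_vocab logits per evaluated position. A minimal standalone sketch of how that pair behaves (the free-standing names and the vocabulary size of 4 are hypothetical, for illustration only):

    import numpy as np

    n_vocab = 4  # hypothetical vocabulary size

    # Mirrors of the new state containers: one token id per evaluated
    # position, and one row of n_vocab logits per evaluated position.
    input_ids = np.array([], dtype=np.intc)
    scores = np.ndarray((0, n_vocab), dtype=np.single)

    # After evaluating two tokens the caches would grow in lockstep:
    input_ids = np.concatenate([input_ids, np.array([11, 13], dtype=np.intc)])
    scores = np.concatenate([scores, np.zeros((2, n_vocab), dtype=np.single)])
    assert input_ids.shape == (2,) and scores.shape == (2, n_vocab)

    # reset() re-creates the empty arrays rather than clearing in place:
    input_ids = np.array([], dtype=np.intc)
    scores = np.ndarray((0, n_vocab), dtype=np.single)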
@@ -310,7 +312,7 @@ class Llama:
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
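Note: only the source of the cache length changes here; the clamping logic is untouched. A sketch of the batching arithmetic with hypothetical sizes:

    # All sizes are hypothetical, chosen to show the clamp taking effect.
    n_ctx = 8    # context window
    n_batch = 4  # tokens submitted per llama_eval call
    cached = 6   # len(self._input_ids): positions already evaluated
    tokens = [5, 6, 7, 8, 9]

    for i in range(0, len(tokens), n_batch):
        batch = tokens[i : min(len(tokens), i + n_batch)]
        # n_past is clamped so that reused past + new batch fits in n_ctx.
        n_past = min(n_ctx - len(batch), cached)
        print(f"batch={batch} n_past={n_past}")
        cached += len(batch)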
@@ -356,6 +358,7 @@ class Llama:
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
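Note: the new assert is the numpy analogue of the deque length check above it, guarding against sampling before anything has been evaluated. A tiny standalone illustration:

    import numpy as np

    scores = np.ndarray((0, 4), dtype=np.single)  # nothing evaluated yet
    try:
        assert scores.shape[0] > 0
    except AssertionError:
        print("sampling before the first eval() is rejected")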
@@ -368,7 +371,7 @@ class Llama:
 
         if logits_processor is not None:
             logits = np.array(
-                logits_processor(list(self.eval_tokens), logits.tolist()),
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
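Note: a logits processor receives the full token history and the current logits and returns modified logits; with _input_ids the history now crosses the boundary via .tolist() instead of list(deque). A sketch with a hypothetical processor (the function name and the fake logits are illustrative, not from the library):

    import numpy as np

    def no_zero_token(input_ids: list, logits: list) -> list:
        """Hypothetical processor: never sample token id 0."""
        logits = list(logits)
        logits[0] = -float("inf")
        return logits

    scores = np.zeros((1, 4), dtype=np.single)   # one row of fake logits
    input_ids = np.array([1, 2], dtype=np.intc)

    # As in the diff: both arguments are plain lists, and the processed
    # row is written back into the scores matrix.
    scores[-1, :] = np.array(
        no_zero_token(input_ids.tolist(), scores[-1, :].tolist()),
        dtype=np.single,
    )
    print(scores[-1, :])  # [-inf, 0, 0, 0]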
@@ -498,8 +501,8 @@ class Llama:
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
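Note: the repeat-penalty buffer must always be exactly last_n_tokens_size entries long, so it is left-padded with token 0 when fewer tokens have been evaluated; the numpy version slices the tail and converts it with .tolist(). A worked sketch, using plain ints in place of llama_cpp.llama_token:

    import numpy as np

    last_n_tokens_size = 5                          # hypothetical buffer size
    input_ids = np.array([7, 8, 9], dtype=np.intc)  # only 3 tokens evaluated

    # Left-pad so the buffer has a fixed length, then append the real tail.
    last_n_tokens_data = [0] * max(
        0, last_n_tokens_size - len(input_ids)
    ) + input_ids[-last_n_tokens_size:].tolist()

    print(last_n_tokens_data)  # [0, 0, 7, 8, 9]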
@@ -557,9 +560,9 @@ class Llama:
         """
         assert self.ctx is not None
 
-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
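Note: the prefix scan works unchanged over a numpy array, since zip iterates it element by element. A standalone sketch with hypothetical token values:

    import numpy as np

    input_ids = np.array([1, 2, 3, 4], dtype=np.intc)  # cached tokens
    tokens = [1, 2, 3, 9, 9]                           # new prompt

    # tokens[:-1] keeps at least one token to re-evaluate, so there are
    # always fresh logits available for the next sampling step.
    longest_prefix = 0
    for a, b in zip(input_ids, tokens[:-1]):
        if a == b:
            longest_prefix += 1
        else:
            break

    print(longest_prefix)  # 3: the caches can be kept up to this point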
@@ -569,6 +572,8 @@ class Llama:
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
                 try:
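Note: on a prefix-match hit both numpy caches are truncated to the shared prefix in a single slice each, whereas the deques still need the per-item pop() loop below. Sketch:

    import numpy as np

    input_ids = np.array([1, 2, 3, 4], dtype=np.intc)
    scores = np.zeros((4, 4), dtype=np.single)
    longest_prefix = 3

    # One slice per array replaces the pop() loop the deques require.
    input_ids = input_ids[:longest_prefix]
    scores = scores[:longest_prefix, :]
    assert len(input_ids) == 3 and scores.shape == (3, 4)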
@@ -595,7 +600,7 @@ class Llama:
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
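Note: a stopping criterion, like a logits processor, gets the full history plus the last row of logits, both converted to plain lists. A hypothetical criterion (name and cutoff are illustrative):

    def stop_after_eight(input_ids: list, logits: list) -> bool:
        """Hypothetical criterion: stop once 8 tokens have been evaluated."""
        return len(input_ids) >= 8

    # Called as in the diff, with both arguments as plain lists:
    #   stopping_criteria(self._input_ids.tolist(), self._scores[-1, :].tolist())
    print(stop_after_eight(list(range(10)), [0.0, 0.0]))  # True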
@@ -820,7 +825,7 @@ class Llama:
                             self.detokenize(completion_tokens[:returned_tokens])
                         )
                         token_offset = len(prompt_tokens) + returned_tokens
-                        logits = self.eval_logits[token_offset - 1]
+                        logits = self._scores[token_offset - 1, :].tolist()
                         current_logprobs = Llama.logits_to_logprobs(logits)
                         sorted_logprobs = list(
                             sorted(
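Note: where eval_logits[k] used to yield a list directly, a row of _scores must be sliced out and converted with .tolist() before it is handed to logits_to_logprobs, which computes a log-softmax. A sketch of that quantity (this is an assumption about what the helper computes, not the library's own code):

    import math

    def log_softmax(logits: list) -> list:
        """Sketch of what logits_to_logprobs computes."""
        m = max(logits)  # subtract the max for numerical stability
        total = sum(math.exp(x - m) for x in logits)
        return [x - m - math.log(total) for x in logits]

    row = [1.0, 2.0, 3.0]    # e.g. self._scores[k, :].tolist()
    print(log_softmax(row))  # exp() of these values sums to 1.0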
@@ -869,7 +874,7 @@ class Llama:
                     break
 
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -899,7 +904,7 @@ class Llama:
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
-                    logits = self.eval_logits[token_offset]
+                    logits = self._scores[token_offset, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
|
@ -1001,8 +1006,7 @@ class Llama:
|
|||
for token in all_tokens
|
||||
]
|
||||
all_logprobs = [
|
||||
Llama.logits_to_logprobs(list(map(float, row)))
|
||||
for row in self.eval_logits
|
||||
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
|
||||
][token_offset:]
|
||||
for token, token_str, logprobs_token in zip(
|
||||
all_tokens, all_token_strs, all_logprobs
|
||||
|
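Note: with _scores a 2-D array, the per-row list(map(float, row)) wrapper disappears; iterating the array yields 1-D rows, and .tolist() already produces plain Python floats. Sketch:

    import numpy as np

    scores = np.array([[0.1, 0.2, 0.7]], dtype=np.single)  # hypothetical cache

    # Iterating a 2-D array yields 1-D rows; .tolist() hands back plain
    # Python floats, so no explicit float() conversion is needed.
    for row in scores:
        print(row.tolist())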