Replace eval_logits and eval_tokens with numpy arrays

This commit is contained in:
Andrei Betlen 2023-05-26 20:03:31 -04:00
parent efb763bcdc
commit fe331ec589

View file

@@ -299,6 +299,8 @@ class Llama:
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)

     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -310,7 +312,7 @@ class Llama:
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
@@ -356,6 +358,7 @@ class Llama:
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -368,7 +371,7 @@ class Llama:
         if logits_processor is not None:
             logits = np.array(
-                logits_processor(list(self.eval_tokens), logits.tolist()),
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
@@ -498,8 +501,8 @@ class Llama:
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
@@ -557,9 +560,9 @@ class Llama:
         """
         assert self.ctx is not None
-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -569,6 +572,8 @@ class Llama:
                     print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
         try:
@@ -595,7 +600,7 @@ class Llama:
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
@@ -820,7 +825,7 @@ class Llama:
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
-                    logits = self.eval_logits[token_offset - 1]
+                    logits = self._scores[token_offset - 1, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -869,7 +874,7 @@ class Llama:
                 break
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -899,7 +904,7 @@ class Llama:
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens - 1
-                    logits = self.eval_logits[token_offset]
+                    logits = self._scores[token_offset, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -1001,8 +1006,7 @@ class Llama:
             for token in all_tokens
         ]
         all_logprobs = [
-            Llama.logits_to_logprobs(list(map(float, row)))
-            for row in self.eval_logits
+            Llama.logits_to_logprobs(row.tolist()) for row in self._scores
         ][token_offset:]
         for token, token_str, logprobs_token in zip(
             all_tokens, all_token_strs, all_logprobs