Replace eval_logits and eval_tokens with numpy arrays
Parent: efb763bcdc
Commit: fe331ec589
1 changed file with 16 additions and 12 deletions
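This change shadows the eval_tokens and eval_logits deques with two numpy arrays: self._input_ids holds the token ids evaluated so far (shape (n_tokens,), dtype np.intc), and self._scores holds one row of logits per evaluated token (shape (n_tokens, n_vocab), dtype np.single). The hunks below swap reads of the deques for reads of the arrays. A minimal sketch of the pattern follows; the append step in eval() is not part of the hunks shown here, so it is an assumption:

    import numpy as np

    n_vocab = 32000  # illustrative vocabulary size

    # Empty state, exactly as reset() now creates it:
    input_ids = np.array([], dtype=np.intc)             # shape (0,)
    scores = np.ndarray((0, n_vocab), dtype=np.single)  # shape (0, n_vocab)

    # Assumed append step after evaluating one batch (not shown in this diff):
    batch = np.array([15043, 2787], dtype=np.intc)
    batch_scores = np.zeros((len(batch), n_vocab), dtype=np.single)
    input_ids = np.concatenate((input_ids, batch))
    scores = np.concatenate((scores, batch_scores), axis=0)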
@@ -299,6 +299,8 @@ class Llama:
         """Reset the model state."""
         self.eval_tokens.clear()
         self.eval_logits.clear()
+        self._input_ids = np.array([], dtype=np.intc)
+        self._scores = np.ndarray((0, self._n_vocab), dtype=np.single)
 
     def eval(self, tokens: Sequence[int]):
         """Evaluate a list of tokens.
@@ -310,7 +312,7 @@ class Llama:
         n_ctx = self._n_ctx
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
-            n_past = min(n_ctx - len(batch), len(self.eval_tokens))
+            n_past = min(n_ctx - len(batch), len(self._input_ids))
             n_tokens = len(batch)
             return_code = llama_cpp.llama_eval(
                 ctx=self.ctx,
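Here n_past caps how much of the previous context is kept so that the incoming batch still fits in the context window: with n_ctx = 512, 500 tokens already in self._input_ids, and a batch of 32, n_past = min(512 - 32, 500) = 480. Only the length check changes; the deque and the array report the same count.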
@@ -356,6 +358,7 @@ class Llama:
     ):
         assert self.ctx is not None
         assert len(self.eval_logits) > 0
+        assert self._scores.shape[0] > 0
         n_vocab = self._n_vocab
         n_ctx = self._n_ctx
         top_k = llama_cpp.c_int(n_vocab) if top_k.value <= 0 else top_k
@@ -368,7 +371,7 @@ class Llama:
 
         if logits_processor is not None:
             logits = np.array(
-                logits_processor(list(self.eval_tokens), logits.tolist()),
+                logits_processor(self._input_ids.tolist(), logits.tolist()),
                 dtype=np.single,
             )
             self._scores[-1, :] = logits
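For reference, the call site above fixes the logits processor contract: it receives the token ids evaluated so far as a list plus the current logits row, and returns adjusted logits that can be fed back into np.array. A hypothetical processor (the name BanTokenProcessor is illustrative, not a library class):

    class BanTokenProcessor:
        """Hypothetical example: forbid one token by sending its logit to -inf."""

        def __init__(self, banned_token_id: int):
            self.banned_token_id = banned_token_id

        def __call__(self, input_ids, scores):
            scores = list(scores)  # don't mutate the caller's list
            scores[self.banned_token_id] = float("-inf")  # softmax weight becomes 0
            return scores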
@@ -498,8 +501,8 @@ class Llama:
         """
         assert self.ctx is not None
         last_n_tokens_data = [llama_cpp.llama_token(0)] * max(
-            0, self.last_n_tokens_size - len(self.eval_tokens)
-        ) + list(self.eval_tokens)[-self.last_n_tokens_size :]
+            0, self.last_n_tokens_size - len(self._input_ids)
+        ) + self._input_ids[-self.last_n_tokens_size :].tolist()
         return self._sample(
             last_n_tokens_data=(llama_cpp.llama_token * self.last_n_tokens_size)(
                 *last_n_tokens_data
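The expression rewritten above always produces a window of exactly last_n_tokens_size ids for the repeat penalty: left-padded with token id 0 when fewer tokens have been evaluated, otherwise the tail of self._input_ids. The same logic in isolation (illustrative names):

    def last_n_window(input_ids, last_n_tokens_size):
        """Fixed-size window of recent token ids, left-padded with 0."""
        pad = [0] * max(0, last_n_tokens_size - len(input_ids))
        return pad + list(input_ids[-last_n_tokens_size:])

    assert last_n_window([7, 8, 9], 5) == [0, 0, 7, 8, 9]        # padded
    assert last_n_window([1, 2, 3, 4, 5, 6], 4) == [3, 4, 5, 6]  # tail only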
@@ -557,9 +560,9 @@ class Llama:
         """
         assert self.ctx is not None
 
-        if reset and len(self.eval_tokens) > 0:
+        if reset and len(self._input_ids) > 0:
             longest_prefix = 0
-            for a, b in zip(self.eval_tokens, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens[:-1]):
                 if a == b:
                     longest_prefix += 1
                 else:
@@ -569,6 +572,8 @@ class Llama:
                 print("Llama.generate: prefix-match hit", file=sys.stderr)
                 reset = False
                 tokens = tokens[longest_prefix:]
+                self._input_ids = self._input_ids[:longest_prefix]
+                self._scores = self._scores[:longest_prefix, :]
                 for _ in range(len(self.eval_tokens) - longest_prefix):
                     self.eval_tokens.pop()
         try:
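The two added lines complete the prefix-match optimization for the numpy state: when a new prompt shares a prefix with the cached tokens, the arrays are truncated to that prefix and only the suffix is re-evaluated. Note the zip against tokens[:-1], which guarantees at least one token is left to evaluate so fresh logits always exist. A standalone sketch of the rollback:

    import numpy as np

    def rollback_to_prefix(input_ids, scores, tokens):
        """Trim cached state to the longest prefix shared with the new prompt."""
        longest_prefix = 0
        for a, b in zip(input_ids, tokens[:-1]):
            if a == b:
                longest_prefix += 1
            else:
                break
        remaining = tokens[longest_prefix:]
        return input_ids[:longest_prefix], scores[:longest_prefix, :], remaining

    cached = np.array([1, 2, 3, 4], dtype=np.intc)
    ids, sc, todo = rollback_to_prefix(cached, np.zeros((4, 8)), [1, 2, 9, 9])
    assert ids.tolist() == [1, 2] and sc.shape == (2, 8) and todo == [9, 9]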
@@ -595,7 +600,7 @@ class Llama:
                 logits_processor=logits_processor,
             )
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 return
             tokens_or_none = yield token
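This call site likewise fixes the stopping-criteria contract: a callable over the token ids so far and the logits of the last position, returning True to stop generation. A hypothetical example (MaxLengthCriteria is an illustrative name):

    class MaxLengthCriteria:
        """Hypothetical example: stop once the sequence reaches max_length."""

        def __init__(self, max_length: int):
            self.max_length = max_length

        def __call__(self, input_ids, scores):
            return len(input_ids) >= self.max_length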
@@ -820,7 +825,7 @@ class Llama:
                         self.detokenize(completion_tokens[:returned_tokens])
                     )
                     token_offset = len(prompt_tokens) + returned_tokens
-                    logits = self.eval_logits[token_offset - 1]
+                    logits = self._scores[token_offset - 1, :].tolist()
                     current_logprobs = Llama.logits_to_logprobs(logits)
                     sorted_logprobs = list(
                         sorted(
@@ -869,7 +874,7 @@ class Llama:
                     break
 
             if stopping_criteria is not None and stopping_criteria(
-                list(self.eval_tokens), self.eval_logits[-1]
+                self._input_ids.tolist(), self._scores[-1, :].tolist()
             ):
                 text = self.detokenize(completion_tokens)
                 finish_reason = "stop"
@@ -899,7 +904,7 @@ class Llama:
                     self.detokenize(completion_tokens[:returned_tokens])
                 )
                 token_offset = len(prompt_tokens) + returned_tokens - 1
-                logits = self.eval_logits[token_offset]
+                logits = self._scores[token_offset, :].tolist()
                 current_logprobs = Llama.logits_to_logprobs(logits)
                 sorted_logprobs = list(
                     sorted(
@@ -1001,8 +1006,7 @@ class Llama:
                 for token in all_tokens
             ]
             all_logprobs = [
-                Llama.logits_to_logprobs(list(map(float, row)))
-                for row in self.eval_logits
+                Llama.logits_to_logprobs(row.tolist()) for row in self._scores
            ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
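Each row of self._scores is one full set of logits, and Llama.logits_to_logprobs turns it into log-probabilities, i.e. a log-softmax: logprob_i = logit_i - log(sum_j exp(logit_j)). A numerically stable standalone equivalent (a sketch under that assumption, not the library's exact implementation):

    import numpy as np

    def logits_to_logprobs(row):
        """Log-softmax over one row of logits, stabilized by subtracting the max."""
        row = np.asarray(row, dtype=np.single)
        shifted = row - row.max()
        return (shifted - np.log(np.exp(shifted).sum())).tolist()

    # Sanity check: probabilities recovered from the logprobs sum to 1.
    assert abs(sum(np.exp(logits_to_logprobs([0.5, 1.5, -2.0]))) - 1.0) < 1e-5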