perf: Don't convert logprobs arrays to lists (#1021)
This commit is contained in:
parent
62944df142
commit
6b2e0e05b4
1 changed files with 6 additions and 7 deletions
|
@ -1552,7 +1552,7 @@ class Llama:
|
|||
self.detokenize(completion_tokens[:returned_tokens])
|
||||
)
|
||||
token_offset = len(prompt_tokens) + returned_tokens
|
||||
logits = self._scores[token_offset - 1, :].tolist()
|
||||
logits = self._scores[token_offset - 1, :]
|
||||
current_logprobs = Llama.logits_to_logprobs(logits)
|
||||
sorted_logprobs = list(
|
||||
sorted(
|
||||
|
@ -1671,7 +1671,7 @@ class Llama:
|
|||
self.detokenize(completion_tokens[:returned_tokens])
|
||||
)
|
||||
token_offset = len(prompt_tokens) + returned_tokens - 1
|
||||
logits = self._scores[token_offset, :].tolist()
|
||||
logits = self._scores[token_offset, :]
|
||||
current_logprobs = Llama.logits_to_logprobs(logits)
|
||||
sorted_logprobs = list(
|
||||
sorted(
|
||||
|
@ -1785,9 +1785,8 @@ class Llama:
|
|||
self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||
for token in all_tokens
|
||||
]
|
||||
all_logprobs = [
|
||||
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
|
||||
][token_offset:]
|
||||
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
|
||||
# TODO: may be able to change this loop to use np.take_along_dim
|
||||
for token, token_str, logprobs_token in zip(
|
||||
all_tokens, all_token_strs, all_logprobs
|
||||
):
|
||||
|
@ -2282,7 +2281,7 @@ class Llama:
|
|||
|
||||
@staticmethod
|
||||
def logits_to_logprobs(
|
||||
logits: Union[List, npt.NDArray[np.single]], axis: int = -1
|
||||
logits: Union[npt.NDArray[np.single], List], axis: int = -1
|
||||
) -> npt.NDArray[np.single]:
|
||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
|
||||
logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
|
||||
|
@ -2293,7 +2292,7 @@ class Llama:
|
|||
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
|
||||
exp = np.exp(subtract_maxs)
|
||||
# Suppress warnings about log of zero
|
||||
with np.errstate(divide='ignore'):
|
||||
with np.errstate(divide="ignore"):
|
||||
summed = np.sum(exp, axis=axis, keepdims=True)
|
||||
out = np.log(summed)
|
||||
return subtract_maxs - out
|
||||
|
|
Loading…
Reference in a new issue