perf: Don't convert logprobs arrays to lists (#1021)

This commit is contained in:
kddubey 2023-12-18 11:28:12 -08:00 committed by GitHub
parent 62944df142
commit 6b2e0e05b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1552,7 +1552,7 @@ class Llama:
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens
logits = self._scores[token_offset - 1, :].tolist()
logits = self._scores[token_offset - 1, :]
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
@ -1671,7 +1671,7 @@ class Llama:
self.detokenize(completion_tokens[:returned_tokens])
)
token_offset = len(prompt_tokens) + returned_tokens - 1
logits = self._scores[token_offset, :].tolist()
logits = self._scores[token_offset, :]
current_logprobs = Llama.logits_to_logprobs(logits)
sorted_logprobs = list(
sorted(
@ -1785,9 +1785,8 @@ class Llama:
self.detokenize([token]).decode("utf-8", errors="ignore")
for token in all_tokens
]
all_logprobs = [
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
][token_offset:]
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
# TODO: may be able to change this loop to use np.take_along_dim
for token, token_str, logprobs_token in zip(
all_tokens, all_token_strs, all_logprobs
):
@ -2282,7 +2281,7 @@ class Llama:
@staticmethod
def logits_to_logprobs(
logits: Union[List, npt.NDArray[np.single]], axis: int = -1
logits: Union[npt.NDArray[np.single], List], axis: int = -1
) -> npt.NDArray[np.single]:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
@ -2293,7 +2292,7 @@ class Llama:
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
exp = np.exp(subtract_maxs)
# Suppress warnings about log of zero
with np.errstate(divide='ignore'):
with np.errstate(divide="ignore"):
summed = np.sum(exp, axis=axis, keepdims=True)
out = np.log(summed)
return subtract_maxs - out