perf: Don't convert logprobs arrays to lists (#1021)
This commit is contained in:
parent
62944df142
commit
6b2e0e05b4
1 changed file with 6 additions and 7 deletions
|
@@ -1552,7 +1552,7 @@ class Llama:
|
||||||
self.detokenize(completion_tokens[:returned_tokens])
|
self.detokenize(completion_tokens[:returned_tokens])
|
||||||
)
|
)
|
||||||
token_offset = len(prompt_tokens) + returned_tokens
|
token_offset = len(prompt_tokens) + returned_tokens
|
||||||
logits = self._scores[token_offset - 1, :].tolist()
|
logits = self._scores[token_offset - 1, :]
|
||||||
current_logprobs = Llama.logits_to_logprobs(logits)
|
current_logprobs = Llama.logits_to_logprobs(logits)
|
||||||
sorted_logprobs = list(
|
sorted_logprobs = list(
|
||||||
sorted(
|
sorted(
|
||||||
|
@@ -1671,7 +1671,7 @@ class Llama:
|
||||||
self.detokenize(completion_tokens[:returned_tokens])
|
self.detokenize(completion_tokens[:returned_tokens])
|
||||||
)
|
)
|
||||||
token_offset = len(prompt_tokens) + returned_tokens - 1
|
token_offset = len(prompt_tokens) + returned_tokens - 1
|
||||||
logits = self._scores[token_offset, :].tolist()
|
logits = self._scores[token_offset, :]
|
||||||
current_logprobs = Llama.logits_to_logprobs(logits)
|
current_logprobs = Llama.logits_to_logprobs(logits)
|
||||||
sorted_logprobs = list(
|
sorted_logprobs = list(
|
||||||
sorted(
|
sorted(
|
||||||
|
@@ -1785,9 +1785,8 @@ class Llama:
|
||||||
self.detokenize([token]).decode("utf-8", errors="ignore")
|
self.detokenize([token]).decode("utf-8", errors="ignore")
|
||||||
for token in all_tokens
|
for token in all_tokens
|
||||||
]
|
]
|
||||||
all_logprobs = [
|
all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
|
||||||
Llama.logits_to_logprobs(row.tolist()) for row in self._scores
|
# TODO: may be able to change this loop to use np.take_along_dim
|
||||||
][token_offset:]
|
|
||||||
for token, token_str, logprobs_token in zip(
|
for token, token_str, logprobs_token in zip(
|
||||||
all_tokens, all_token_strs, all_logprobs
|
all_tokens, all_token_strs, all_logprobs
|
||||||
):
|
):
|
||||||
|
@@ -2282,7 +2281,7 @@ class Llama:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def logits_to_logprobs(
|
def logits_to_logprobs(
|
||||||
logits: Union[List, npt.NDArray[np.single]], axis: int = -1
|
logits: Union[npt.NDArray[np.single], List], axis: int = -1
|
||||||
) -> npt.NDArray[np.single]:
|
) -> npt.NDArray[np.single]:
|
||||||
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
|
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html
|
||||||
logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
|
logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True)
|
||||||
|
@@ -2293,7 +2292,7 @@ class Llama:
|
||||||
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
|
subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single)
|
||||||
exp = np.exp(subtract_maxs)
|
exp = np.exp(subtract_maxs)
|
||||||
# Suppress warnings about log of zero
|
# Suppress warnings about log of zero
|
||||||
with np.errstate(divide='ignore'):
|
with np.errstate(divide="ignore"):
|
||||||
summed = np.sum(exp, axis=axis, keepdims=True)
|
summed = np.sum(exp, axis=axis, keepdims=True)
|
||||||
out = np.log(summed)
|
out = np.log(summed)
|
||||||
return subtract_maxs - out
|
return subtract_maxs - out
|
||||||
|
|
Loading…
Reference in a new issue