From 6b2e0e05b40f23e43a31720bcd26ec137a74507a Mon Sep 17 00:00:00 2001 From: kddubey Date: Mon, 18 Dec 2023 11:28:12 -0800 Subject: [PATCH] perf: Don't convert logprobs arrays to lists (#1021) --- llama_cpp/llama.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 491d7cb..e9f3dba 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1552,7 +1552,7 @@ class Llama: self.detokenize(completion_tokens[:returned_tokens]) ) token_offset = len(prompt_tokens) + returned_tokens - logits = self._scores[token_offset - 1, :].tolist() + logits = self._scores[token_offset - 1, :] current_logprobs = Llama.logits_to_logprobs(logits) sorted_logprobs = list( sorted( @@ -1671,7 +1671,7 @@ class Llama: self.detokenize(completion_tokens[:returned_tokens]) ) token_offset = len(prompt_tokens) + returned_tokens - 1 - logits = self._scores[token_offset, :].tolist() + logits = self._scores[token_offset, :] current_logprobs = Llama.logits_to_logprobs(logits) sorted_logprobs = list( sorted( @@ -1785,9 +1785,8 @@ class Llama: self.detokenize([token]).decode("utf-8", errors="ignore") for token in all_tokens ] - all_logprobs = [ - Llama.logits_to_logprobs(row.tolist()) for row in self._scores - ][token_offset:] + all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:] + # TODO: may be able to change this loop to use np.take_along_axis for token, token_str, logprobs_token in zip( all_tokens, all_token_strs, all_logprobs ): @@ -2282,7 +2281,7 @@ class Llama: @staticmethod def logits_to_logprobs( - logits: Union[List, npt.NDArray[np.single]], axis: int = -1 + logits: Union[npt.NDArray[np.single], List], axis: int = -1 ) -> npt.NDArray[np.single]: # https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.log_softmax.html logits_maxs: np.ndarray = np.amax(logits, axis=axis, keepdims=True) @@ -2293,7 +2292,7 @@ class Llama: subtract_maxs = np.subtract(logits, logits_maxs, dtype=np.single) exp = 
np.exp(subtract_maxs) # Suppress warnings about log of zero - with np.errstate(divide='ignore'): + with np.errstate(divide="ignore"): summed = np.sum(exp, axis=axis, keepdims=True) out = np.log(summed) return subtract_maxs - out