From 2f03fb02316a2877f15bac951d7133fddf19ff19 Mon Sep 17 00:00:00 2001 From: twaka Date: Fri, 22 Dec 2023 14:03:29 +0900 Subject: [PATCH] fix text_offset of multi-token characters (#1037) * fix text_offsets for bytes tokens * fix --- llama_cpp/llama.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index c2c0455..ef5e347 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1551,7 +1551,9 @@ class Llama: "utf-8", errors="ignore" ) text_offset = len(prompt) + len( - self.detokenize(completion_tokens[:returned_tokens]) + self.detokenize(completion_tokens[:returned_tokens]).decode( + "utf-8", errors="ignore" + ) ) token_offset = len(prompt_tokens) + returned_tokens logits = self._scores[token_offset - 1, :] @@ -1789,13 +1791,19 @@ class Llama: ] all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:] # TODO: may be able to change this loop to use np.take_along_dim - for token, token_str, logprobs_token in zip( - all_tokens, all_token_strs, all_logprobs + for idx, (token, token_str, logprobs_token) in enumerate( + zip(all_tokens, all_token_strs, all_logprobs) ): if token == self.token_bos(): continue - text_offsets.append(text_offset) - text_offset += len(token_str) + text_offsets.append( + text_offset + + len( + self.detokenize(all_tokens[:idx]).decode( + "utf-8", errors="ignore" + ) + ) + ) tokens.append(token_str) sorted_logprobs = list( sorted(