fix text_offset of multi-token characters (#1037)

* fix text_offsets for bytes tokens

* fix
This commit is contained in:
twaka 2023-12-22 14:03:29 +09:00 committed by GitHub
parent 33cc623346
commit 2f03fb0231
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@@ -1551,7 +1551,9 @@ class Llama:
                             "utf-8", errors="ignore"
                         )
                         text_offset = len(prompt) + len(
-                            self.detokenize(completion_tokens[:returned_tokens])
+                            self.detokenize(completion_tokens[:returned_tokens]).decode(
+                                "utf-8", errors="ignore"
+                            )
                         )
                         token_offset = len(prompt_tokens) + returned_tokens
                         logits = self._scores[token_offset - 1, :]
@@ -1789,13 +1791,19 @@ class Llama:
                 ]
                 all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
                 # TODO: may be able to change this loop to use np.take_along_dim
-                for token, token_str, logprobs_token in zip(
-                    all_tokens, all_token_strs, all_logprobs
+                for idx, (token, token_str, logprobs_token) in enumerate(
+                    zip(all_tokens, all_token_strs, all_logprobs)
                 ):
                     if token == self.token_bos():
                         continue
-                    text_offsets.append(text_offset)
-                    text_offset += len(token_str)
+                    text_offsets.append(
+                        text_offset
+                        + len(
+                            self.detokenize(all_tokens[:idx]).decode(
+                                "utf-8", errors="ignore"
+                            )
+                        )
+                    )
                     tokens.append(token_str)
                     sorted_logprobs = list(
                         sorted(