fix text_offset of multi-token characters (#1037)

* fix text_offsets for bytes tokens

* fix
Authored by twaka on 2023-12-22 14:03:29 +09:00, committed by GitHub
parent 33cc623346
commit 2f03fb0231

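Background for the change: llama.cpp's byte-level vocabulary can split a single multi-byte UTF-8 character (an emoji, for instance) across several tokens, and Llama.detokenize() returns raw bytes. Decoding each token's bytes in isolation with errors="ignore" silently drops the incomplete sequence, so any offset bookkeeping built from per-token decoded strings stalls. A minimal stdlib-only sketch of the effect (the token split below is hypothetical but representative):

# Hypothetical byte-level split of the 4-byte emoji U+1F600 across two tokens.
tokens = [b"\xf0\x9f", b"\x98\x80"]

# Decoding each fragment alone fails (invalid UTF-8), so errors="ignore"
# drops it and every token contributes a length of 0.
print(sum(len(t.decode("utf-8", errors="ignore")) for t in tokens))  # 0

# Decoding the joined prefix instead recovers the character.
print(len(b"".join(tokens).decode("utf-8", errors="ignore")))  # 1
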
@@ -1551,7 +1551,9 @@ class Llama:
                             "utf-8", errors="ignore"
                         )
                         text_offset = len(prompt) + len(
-                            self.detokenize(completion_tokens[:returned_tokens])
+                            self.detokenize(completion_tokens[:returned_tokens]).decode(
+                                "utf-8", errors="ignore"
+                            )
                         )
                         token_offset = len(prompt_tokens) + returned_tokens
                         logits = self._scores[token_offset - 1, :]
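
In this first hunk, self.detokenize() returns bytes, so the old len() counted UTF-8 bytes rather than characters; the two diverge as soon as the text contains non-ASCII characters, pushing text_offset out of sync with the decoded completion text. Decoding before measuring keeps the offset in character units. A quick illustration with an arbitrary sample string:

raw = "café".encode("utf-8")
print(len(raw))                                   # 5 -- "é" takes 2 bytes
print(len(raw.decode("utf-8", errors="ignore")))  # 4 characters
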
@@ -1789,13 +1791,19 @@ class Llama:
             ]
             all_logprobs = Llama.logits_to_logprobs(self._scores)[token_offset:]
             # TODO: may be able to change this loop to use np.take_along_dim
-            for token, token_str, logprobs_token in zip(
-                all_tokens, all_token_strs, all_logprobs
+            for idx, (token, token_str, logprobs_token) in enumerate(
+                zip(all_tokens, all_token_strs, all_logprobs)
             ):
                 if token == self.token_bos():
                     continue
-                text_offsets.append(text_offset)
-                text_offset += len(token_str)
+                text_offsets.append(
+                    text_offset
+                    + len(
+                        self.detokenize(all_tokens[:idx]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                    )
+                )
                 tokens.append(token_str)
             sorted_logprobs = list(
                 sorted(
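
The second hunk drops the running text_offset += len(token_str) accumulator: when a token carries only part of a UTF-8 sequence, token_str decodes to "", and the running offset falls behind for every token that follows. Recomputing each offset from the decoded length of the whole token prefix (all_tokens[:idx]) avoids the drift, since incomplete bytes can only ever sit at the tail of the prefix. A rough stand-in for the two strategies, reusing the hypothetical emoji split and b"".join() in place of detokenize():

tokens = [b"hi ", b"\xf0\x9f", b"\x98\x80", b"!"]  # full text: "hi 😀!"
old_offsets, new_offsets, running = [], [], 0
for idx, tok in enumerate(tokens):
    old_offsets.append(running)  # old: accumulate per-token decoded lengths
    running += len(tok.decode("utf-8", errors="ignore"))
    # new: decode the whole prefix before this token, then measure it
    new_offsets.append(len(b"".join(tokens[:idx]).decode("utf-8", errors="ignore")))
print(old_offsets)  # [0, 3, 3, 3] -- the final "!" is misplaced at offset 3
print(new_offsets)  # [0, 3, 3, 4] -- "!" correctly lands at offset 4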