Merge pull request #309 from MeouSker77/fix-CJK

Fix CJK and emoji stream output
This commit is contained in:
Andrei 2023-08-29 06:58:10 -04:00 committed by GitHub
commit bae44ec8bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1004,13 +1004,15 @@ class Llama:
break break
token_end_position = 0 token_end_position = 0
for token in remaining_tokens:
token_end_position += len(self.detokenize([token])) if logprobs is not None:
# Check if stop sequence is in the token # not sure how to handle this branch when dealing
if token_end_position >= (remaining_length - first_stop_position): # with CJK output, so keep it unchanged
break for token in remaining_tokens:
logprobs_or_none: Optional[CompletionLogprobs] = None token_end_position += len(self.detokenize([token]))
if logprobs is not None: # Check if stop sequence is in the token
if token_end_position > (remaining_length - first_stop_position):
break
token_str = self.detokenize([token]).decode( token_str = self.detokenize([token]).decode(
"utf-8", errors="ignore" "utf-8", errors="ignore"
) )
@ -1043,23 +1045,58 @@ class Llama:
"token_logprobs": [current_logprobs[int(token)]], "token_logprobs": [current_logprobs[int(token)]],
"top_logprobs": [top_logprob], "top_logprobs": [top_logprob],
} }
returned_tokens += 1 returned_tokens += 1
yield { yield {
"id": completion_id, "id": completion_id,
"object": "text_completion", "object": "text_completion",
"created": created, "created": created,
"model": model_name, "model": model_name,
"choices": [ "choices": [
{ {
"text": self.detokenize([token]).decode( "text": self.detokenize([token]).decode(
"utf-8", errors="ignore" "utf-8", errors="ignore"
), ),
"index": 0, "index": 0,
"logprobs": logprobs_or_none, "logprobs": logprobs_or_none,
"finish_reason": None, "finish_reason": None,
} }
], ],
} }
else:
while len(remaining_tokens) > 0:
decode_success = False
for i in range(1, len(remaining_tokens) + 1):
tokens = remaining_tokens[:i]
try:
bs = self.detokenize(tokens)
text = bs.decode('utf-8')
decode_success = True
break
except UnicodeError:
pass
if not decode_success:
# all remaining tokens cannot be decoded to a UTF-8 character
break
token_end_position += len(bs)
if token_end_position > (remaining_length - first_stop_position):
break
remaining_tokens = remaining_tokens[i:]
returned_tokens += i
yield {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model_name,
"choices": [
{
"text": text,
"index": 0,
"logprobs": None,
"finish_reason": None,
}
],
}
if len(completion_tokens) >= max_tokens: if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens) text = self.detokenize(completion_tokens)