Merge pull request #309 from MeouSker77/fix-CJK

Fix CJK and emoji stream output
This commit is contained in:
Andrei 2023-08-29 06:58:10 -04:00 committed by GitHub
commit bae44ec8bf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1004,13 +1004,15 @@ class Llama:
break break
token_end_position = 0 token_end_position = 0
if logprobs is not None:
# not sure how to handle this branch when dealing
# with CJK output, so keep it unchanged
for token in remaining_tokens: for token in remaining_tokens:
token_end_position += len(self.detokenize([token])) token_end_position += len(self.detokenize([token]))
# Check if stop sequence is in the token # Check if stop sequence is in the token
if token_end_position >= (remaining_length - first_stop_position): if token_end_position > (remaining_length - first_stop_position):
break break
logprobs_or_none: Optional[CompletionLogprobs] = None
if logprobs is not None:
token_str = self.detokenize([token]).decode( token_str = self.detokenize([token]).decode(
"utf-8", errors="ignore" "utf-8", errors="ignore"
) )
@ -1060,6 +1062,41 @@ class Llama:
} }
], ],
} }
else:
while len(remaining_tokens) > 0:
decode_success = False
for i in range(1, len(remaining_tokens) + 1):
tokens = remaining_tokens[:i]
try:
bs = self.detokenize(tokens)
text = bs.decode('utf-8')
decode_success = True
break
except UnicodeError:
pass
if not decode_success:
# all remaining tokens cannot be decoded to a UTF-8 character
break
token_end_position += len(bs)
if token_end_position > (remaining_length - first_stop_position):
break
remaining_tokens = remaining_tokens[i:]
returned_tokens += i
yield {
"id": completion_id,
"object": "text_completion",
"created": created,
"model": model_name,
"choices": [
{
"text": text,
"index": 0,
"logprobs": None,
"finish_reason": None,
}
],
}
if len(completion_tokens) >= max_tokens: if len(completion_tokens) >= max_tokens:
text = self.detokenize(completion_tokens) text = self.detokenize(completion_tokens)