From 88184ed217b62649eb458afe89a22ae86b289b5e Mon Sep 17 00:00:00 2001 From: MeouSker77 Date: Wed, 9 Aug 2023 22:04:35 +0800 Subject: [PATCH] fix CJK output again --- llama_cpp/llama.py | 85 +++++++++++++++++++++++++++++++++------------- 1 file changed, 61 insertions(+), 24 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index a996d5c..1950a96 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1003,13 +1003,15 @@ class Llama: break token_end_position = 0 - for token in remaining_tokens: - token_end_position += len(self.detokenize([token])) - # Check if stop sequence is in the token - if token_end_position >= (remaining_length - first_stop_position): - break - logprobs_or_none: Optional[CompletionLogprobs] = None - if logprobs is not None: + + if logprobs is not None: + # not sure how to handle this branch when dealing + # with CJK output, so keep it unchanged + for token in remaining_tokens: + token_end_position += len(self.detokenize([token])) + # Check if stop sequence is in the token + if token_end_position > (remaining_length - first_stop_position): + break token_str = self.detokenize([token]).decode( "utf-8", errors="ignore" ) @@ -1042,23 +1044,58 @@ class Llama: "token_logprobs": [current_logprobs[int(token)]], "top_logprobs": [top_logprob], } - returned_tokens += 1 - yield { - "id": completion_id, - "object": "text_completion", - "created": created, - "model": model_name, - "choices": [ - { - "text": self.detokenize([token]).decode( - "utf-8", errors="ignore" - ), - "index": 0, - "logprobs": logprobs_or_none, - "finish_reason": None, - } - ], - } + returned_tokens += 1 + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": self.detokenize([token]).decode( + "utf-8", errors="ignore" + ), + "index": 0, + "logprobs": logprobs_or_none, + "finish_reason": None, + } + ], + } + else: + while len(remaining_tokens) > 0: + decode_success = False + for i in range(1, len(remaining_tokens) + 1): + tokens = remaining_tokens[:i] + try: + bs = self.detokenize(tokens) + text = bs.decode('utf-8') + decode_success = True + break + except UnicodeError: + pass + if not decode_success: + # all remaining tokens cannot be decoded to a UTF-8 character + break + token_end_position += len(bs) + if token_end_position > (remaining_length - first_stop_position): + break + remaining_tokens = remaining_tokens[i:] + returned_tokens += i + + yield { + "id": completion_id, + "object": "text_completion", + "created": created, + "model": model_name, + "choices": [ + { + "text": text, + "index": 0, + "logprobs": None, + "finish_reason": None, + } + ], + } if len(completion_tokens) >= max_tokens: text = self.detokenize(completion_tokens)