From a34d48014192771d2e308a76c22f33bc0318d983 Mon Sep 17 00:00:00 2001 From: Andrei Betlen Date: Mon, 20 Nov 2023 22:50:59 -0500 Subject: [PATCH] Fix #929 --- llama_cpp/llama.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 2e18b47..00f2bcb 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -1323,7 +1323,9 @@ class Llama: completion_id: str = f"cmpl-{str(uuid.uuid4())}" created: int = int(time.time()) - completion_tokens: List[int] = [] + # If prompt is empty, initialize completion with BOS token to avoid + # detokenization including a space at the beginning of the completion + completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()] # Add blank space to start of prompt to match OG llama tokenizer prompt_tokens: List[int] = ( ( @@ -1459,6 +1461,8 @@ class Llama: # not sure how to handle this branch when dealing # with CJK output, so keep it unchanged for token in remaining_tokens: + if token == self.token_bos(): + continue token_end_position += len(self.detokenize([token])) # Check if stop sequence is in the token if token_end_position > ( @@ -1582,6 +1586,8 @@ class Llama: logprobs_or_none: Optional[CompletionLogprobs] = None if logprobs is not None: + if token == self.token_bos(): + continue token_str = self.detokenize([token]).decode( "utf-8", errors="ignore" ) @@ -1709,6 +1715,8 @@ class Llama: for token, token_str, logprobs_token in zip( all_tokens, all_token_strs, all_logprobs ): + if token == self.token_bos(): + continue text_offsets.append(text_offset) text_offset += len(token_str) tokens.append(token_str)