Andrei Betlen 2023-11-20 22:50:59 -05:00
parent 2c2afa320f
commit a34d480141


@@ -1323,7 +1323,9 @@ class Llama:
         completion_id: str = f"cmpl-{str(uuid.uuid4())}"
         created: int = int(time.time())
-        completion_tokens: List[int] = []
+        # If prompt is empty, initialize completion with BOS token to avoid
+        # detokenization including a space at the beginning of the completion
+        completion_tokens: List[int] = [] if len(prompt) > 0 else [self.token_bos()]
         # Add blank space to start of prompt to match OG llama tokenizer
         prompt_tokens: List[int] = (
             (
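Note (not part of the diff): a minimal sketch of the leading-space behavior this hunk works around. The path model.gguf is a placeholder; Llama, tokenize, detokenize, and token_bos are llama-cpp-python's own APIs. SentencePiece-style vocabularies mark word boundaries with a leading "▁", so a bare token list typically round-trips with a leading space, while a list that starts with BOS lets detokenize strip it.

    from llama_cpp import Llama

    # vocab_only loads only the tokenizer, which is enough for this round trip
    llm = Llama(model_path="model.gguf", vocab_only=True)  # placeholder path

    toks = llm.tokenize(b"Hello", add_bos=False)
    # The bare round trip typically comes back as b" Hello" (leading space),
    # because the first piece carries the SentencePiece word-boundary marker
    print(llm.detokenize(toks))
    # With BOS at the front, detokenize drops that leading space, which is
    # why the hunk above seeds completion_tokens with BOS for empty prompts
    print(llm.detokenize([llm.token_bos()] + toks))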
@@ -1459,6 +1461,8 @@ class Llama:
                     # not sure how to handle this branch when dealing
                     # with CJK output, so keep it unchanged
                     for token in remaining_tokens:
+                        if token == self.token_bos():
+                            continue
                         token_end_position += len(self.detokenize([token]))
                         # Check if stop sequence is in the token
                         if token_end_position > (
@@ -1582,6 +1586,8 @@ class Llama:
                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
+                    if token == self.token_bos():
+                        continue
                     token_str = self.detokenize([token]).decode(
                         "utf-8", errors="ignore"
                     )
@@ -1709,6 +1715,8 @@ class Llama:
            for token, token_str, logprobs_token in zip(
                all_tokens, all_token_strs, all_logprobs
            ):
+               if token == self.token_bos():
+                   continue
                text_offsets.append(text_offset)
                text_offset += len(token_str)
                tokens.append(token_str)
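Note (not part of the diff): the guard added in the last three hunks, sketched in isolation, reusing llm from the sketch above. completion_tokens here is a stand-in for the list seeded in the first hunk. The seeded BOS token produces no visible text, so skipping it keeps the logprobs text offsets aligned with the detokenized completion.

    # Stand-in for the list seeded in the first hunk (empty-prompt case)
    completion_tokens = [llm.token_bos()] + llm.tokenize(b"Hello", add_bos=False)

    text_offset = 0
    text_offsets = []
    for token in completion_tokens:
        if token == llm.token_bos():
            continue  # BOS contributes no text, so it gets no offset entry
        token_str = llm.detokenize([token]).decode("utf-8", errors="ignore")
        text_offsets.append(text_offset)
        text_offset += len(token_str)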