fix: revert _create_completions.

Andrei Betlen 2024-02-09 02:02:13 -05:00
parent dfc1b17341
commit e16f06e6eb


@@ -948,8 +948,7 @@ class Llama:
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                prev_tokens = completion_tokens[:returned_tokens]
-                remaining_text = self.detokenize(completion_tokens, prev_tokens)
+                remaining_text = self.detokenize(remaining_tokens)
                 remaining_length = len(remaining_text)
                 # We want to avoid yielding any characters from
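
The effect of this first hunk in isolation: instead of threading the
already-returned prefix through detokenize(), the revert detokenizes only the
not-yet-returned tail of the token stream. A minimal sketch of that pattern,
assuming a toy byte-level tokenizer (one token per byte); the real
Llama.detokenize likewise returns bytes:

    def detokenize(tokens: list[int]) -> bytes:
        # Toy stand-in for Llama.detokenize (assumption for this sketch only).
        return bytes(tokens)

    completion_tokens = [72, 101, 108, 108, 111]  # b"Hello"
    returned_tokens = 2                           # b"He" already yielded

    remaining_tokens = completion_tokens[returned_tokens:]
    remaining_text = detokenize(remaining_tokens)  # b"llo"
    remaining_length = len(remaining_text)
    assert remaining_text == b"llo" and remaining_length == 3
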
@@ -971,13 +970,13 @@ class Llama:
                     for token in remaining_tokens:
                         if token == self.token_bos():
                             continue
-                        token_end_position += len(remaining_text)
+                        token_end_position += len(self.detokenize([token]))
                         # Check if stop sequence is in the token
                         if token_end_position > (
                             remaining_length - first_stop_position
                         ):
                             break
-                        token_str = remaining_text.decode(
+                        token_str = self.detokenize([token]).decode(
                             "utf-8", errors="ignore"
                         )
                         text_offset = len(prompt) + len(
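
The per-token detokenize calls restored here matter for stop-sequence position
tracking and decoding: each token's byte length is accumulated one at a time,
and errors="ignore" keeps a token whose bytes end mid-character from raising.
A sketch of that edge case, splitting one multi-byte character across two
hypothetical tokens:

    euro = "€".encode("utf-8")          # b"\xe2\x82\xac", three bytes
    token_bytes = [euro[:2], euro[2:]]  # pretend two tokens split the char

    token_end_position = 0
    for tb in token_bytes:
        token_end_position += len(tb)  # byte-accurate position tracking
        token_str = tb.decode("utf-8", errors="ignore")
        print(token_end_position, repr(token_str))  # partial bytes decode to ""
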
@@ -1002,7 +1001,11 @@
                         }
                         top_logprob.update({token_str: current_logprobs[int(token)]})
                         logprobs_or_none = {
-                            "tokens": [token_str],
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
                             "text_offset": [text_offset],
                             "token_logprobs": [current_logprobs[int(token)]],
                             "top_logprobs": [top_logprob],
@@ -1015,7 +1018,9 @@
                             "model": model_name,
                             "choices": [
                                 {
-                                    "text": token_str,
+                                    "text": self.detokenize([token]).decode(
+                                        "utf-8", errors="ignore"
+                                    ),
                                     "index": 0,
                                     "logprobs": logprobs_or_none,
                                     "finish_reason": None,
@@ -1027,7 +1032,7 @@
                         decode_success = False
                         for i in range(1, len(remaining_tokens) + 1):
                             try:
-                                bs = remaining_text
+                                bs = self.detokenize(remaining_tokens[:i])
                                 ts = bs.decode("utf-8")
                                 decode_success = True
                                 break
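
This hunk restores the non-logprobs decoding strategy: detokenize progressively
longer prefixes of the remaining tokens until the bytes form valid UTF-8, so a
character split across token boundaries is held back rather than emitted
broken. A self-contained sketch, again assuming a toy byte-level tokenizer:

    def detokenize(tokens: list[int]) -> bytes:
        # Toy stand-in for Llama.detokenize (assumption for this sketch only).
        return bytes(tokens)

    remaining_tokens = list("€!".encode("utf-8"))  # [0xE2, 0x82, 0xAC, 0x21]

    decode_success = False
    for i in range(1, len(remaining_tokens) + 1):
        try:
            bs = detokenize(remaining_tokens[:i])
            ts = bs.decode("utf-8")  # fails until all three "€" bytes arrive
            decode_success = True
            break
        except UnicodeDecodeError:
            pass

    assert decode_success and ts == "€" and i == 3
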
@@ -1063,7 +1068,6 @@
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break