fix: revert _create_completions.
This commit is contained in:
parent
dfc1b17341
commit
e16f06e6eb
1 changed files with 12 additions and 8 deletions
|
@ -948,8 +948,7 @@ class Llama:
|
|||
|
||||
if stream:
|
||||
remaining_tokens = completion_tokens[returned_tokens:]
|
||||
prev_tokens = completion_tokens[:returned_tokens]
|
||||
remaining_text = self.detokenize(completion_tokens, prev_tokens)
|
||||
remaining_text = self.detokenize(remaining_tokens)
|
||||
remaining_length = len(remaining_text)
|
||||
|
||||
# We want to avoid yielding any characters from
|
||||
|
@ -971,13 +970,13 @@ class Llama:
|
|||
for token in remaining_tokens:
|
||||
if token == self.token_bos():
|
||||
continue
|
||||
token_end_position += len(remaining_text)
|
||||
token_end_position += len(self.detokenize([token]))
|
||||
# Check if stop sequence is in the token
|
||||
if token_end_position > (
|
||||
remaining_length - first_stop_position
|
||||
):
|
||||
break
|
||||
token_str = remaining_text.decode(
|
||||
token_str = self.detokenize([token]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
text_offset = len(prompt) + len(
|
||||
|
@ -1002,7 +1001,11 @@ class Llama:
|
|||
}
|
||||
top_logprob.update({token_str: current_logprobs[int(token)]})
|
||||
logprobs_or_none = {
|
||||
"tokens": [token_str],
|
||||
"tokens": [
|
||||
self.detokenize([token]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
)
|
||||
],
|
||||
"text_offset": [text_offset],
|
||||
"token_logprobs": [current_logprobs[int(token)]],
|
||||
"top_logprobs": [top_logprob],
|
||||
|
@ -1015,7 +1018,9 @@ class Llama:
|
|||
"model": model_name,
|
||||
"choices": [
|
||||
{
|
||||
"text": token_str,
|
||||
"text": self.detokenize([token]).decode(
|
||||
"utf-8", errors="ignore"
|
||||
),
|
||||
"index": 0,
|
||||
"logprobs": logprobs_or_none,
|
||||
"finish_reason": None,
|
||||
|
@ -1027,7 +1032,7 @@ class Llama:
|
|||
decode_success = False
|
||||
for i in range(1, len(remaining_tokens) + 1):
|
||||
try:
|
||||
bs = remaining_text
|
||||
bs = self.detokenize(remaining_tokens[:i])
|
||||
ts = bs.decode("utf-8")
|
||||
decode_success = True
|
||||
break
|
||||
|
@ -1063,7 +1068,6 @@ class Llama:
|
|||
|
||||
if len(completion_tokens) >= max_tokens:
|
||||
text = self.detokenize(completion_tokens)
|
||||
|
||||
finish_reason = "length"
|
||||
break
|
||||
|
||||
|
|
Loading…
Reference in a new issue