Merge branch 'main' of github.com:abetlen/llama_cpp_python into main

Andrei Betlen 2024-02-09 13:32:31 -05:00
commit 4abb8c9386
2 changed files with 13 additions and 9 deletions


@@ -950,8 +950,7 @@ class Llama:
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                prev_tokens = completion_tokens[:returned_tokens]
-                remaining_text = self.detokenize(completion_tokens, prev_tokens)
+                remaining_text = self.detokenize(remaining_tokens)
                 remaining_length = len(remaining_text)
                 # We want to avoid yielding any characters from
@@ -973,13 +972,13 @@ class Llama:
                     for token in remaining_tokens:
                         if token == self.token_bos():
                             continue
-                        token_end_position += len(remaining_text)
+                        token_end_position += len(self.detokenize([token]))
                         # Check if stop sequence is in the token
                         if token_end_position > (
                             remaining_length - first_stop_position
                         ):
                             break
-                        token_str = remaining_text.decode(
+                        token_str = self.detokenize([token]).decode(
                             "utf-8", errors="ignore"
                         )
                         text_offset = len(prompt) + len(
@@ -1004,7 +1003,11 @@ class Llama:
                         }
                         top_logprob.update({token_str: current_logprobs[int(token)]})
                         logprobs_or_none = {
-                            "tokens": [token_str],
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
                             "text_offset": [text_offset],
                             "token_logprobs": [current_logprobs[int(token)]],
                             "top_logprobs": [top_logprob],
@@ -1017,7 +1020,9 @@ class Llama:
                             "model": model_name,
                             "choices": [
                                 {
-                                    "text": token_str,
+                                    "text": self.detokenize([token]).decode(
+                                        "utf-8", errors="ignore"
+                                    ),
                                     "index": 0,
                                     "logprobs": logprobs_or_none,
                                     "finish_reason": None,
@@ -1029,7 +1034,7 @@ class Llama:
                 decode_success = False
                 for i in range(1, len(remaining_tokens) + 1):
                     try:
-                        bs = remaining_text
+                        bs = self.detokenize(remaining_tokens[:i])
                         ts = bs.decode("utf-8")
                         decode_success = True
                         break
@@ -1065,7 +1070,6 @@ class Llama:
             if len(completion_tokens) >= max_tokens:
                 text = self.detokenize(completion_tokens)
                 finish_reason = "length"
                 break
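In rough terms, the right-hand side of this diff detokenizes each pending token on its own while streaming, and falls back to decoding progressively longer token prefixes when a multi-byte UTF-8 character is split across token boundaries. The standalone sketch below illustrates that pattern outside the library. The toy detokenize(), VOCAB, stream_pieces() and first_decodable_prefix() names are illustrative stand-ins, not part of llama_cpp; the only behavior carried over from the diff is that Llama.detokenize returns bytes, as the .decode("utf-8") calls above imply.

# Minimal, self-contained sketch (not from the repository) of the streaming
# pattern on the new side of the diff: detokenize the not-yet-returned tokens,
# walk them one token at a time, and retry progressively longer token prefixes
# until the bytes decode as valid UTF-8. The toy detokenize() below stands in
# for Llama.detokenize, which returns bytes in the library.
from typing import Iterator, List

VOCAB = {1: "Hel", 2: "lo", 3: " wor", 4: "ld", 5: "!"}  # toy vocabulary


def detokenize(tokens: List[int]) -> bytes:
    # Concatenate per-token pieces and return raw bytes, like Llama.detokenize.
    return "".join(VOCAB[t] for t in tokens).encode("utf-8")


def stream_pieces(completion_tokens: List[int], returned_tokens: int) -> Iterator[str]:
    # Yield decoded text for the tokens that have not been returned yet,
    # mirroring the per-token detokenize([token]) calls in the new code.
    remaining_tokens = completion_tokens[returned_tokens:]
    remaining_text = detokenize(remaining_tokens)
    remaining_length = len(remaining_text)

    token_end_position = 0
    for token in remaining_tokens:
        token_end_position += len(detokenize([token]))
        if token_end_position > remaining_length:
            # In the real code the bound is remaining_length - first_stop_position,
            # which keeps stop-sequence text from being yielded.
            break
        yield detokenize([token]).decode("utf-8", errors="ignore")


def first_decodable_prefix(remaining_tokens: List[int]) -> str:
    # Mirror of the retry loop at the bottom of the diff: grow the token prefix
    # one token at a time until the byte string decodes cleanly as UTF-8.
    for i in range(1, len(remaining_tokens) + 1):
        try:
            return detokenize(remaining_tokens[:i]).decode("utf-8")
        except UnicodeError:
            continue  # a multi-byte character is still split across tokens
    return ""


if __name__ == "__main__":
    tokens = [1, 2, 3, 4, 5]
    print(list(stream_pieces(tokens, returned_tokens=2)))  # [' wor', 'ld', '!']
    print(first_decodable_prefix(tokens[2:]))              # ' wor'

In the real code the same retry loop assigns decode_success and breaks instead of returning, but the boundary handling is the same idea.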

vendor/llama.cpp vendored

@@ -1 +1 @@
-Subproject commit b08f22c882a1443e6b97081f3ce718a4d1a741f8
+Subproject commit 8e6a9d2de0096af7120606c74ee2f26684e87b41