Fix logprobs for completions and implement logprobs for streaming.

Andrei Betlen 2023-05-19 02:20:27 -04:00
parent a634a2453b
commit 17d4271b04
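
For context, this change populates the "logprobs" field on streamed completion chunks instead of always returning None. A minimal usage sketch, assuming a local GGML model file at a placeholder path (the prompt and argument values are illustrative, not from the commit):

from llama_cpp import Llama

# Placeholder model path; any local GGML model file works here.
llm = Llama(model_path="./models/7B/ggml-model.bin")

# Request the top-3 logprobs per token while streaming.
for chunk in llm.create_completion(
    "Q: Name the planets in the solar system. A:",
    max_tokens=32,
    logprobs=3,
    stream=True,
):
    choice = chunk["choices"][0]
    # After this commit, "logprobs" carries per-token data in each chunk.
    print(choice["text"], choice["logprobs"])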


@@ -710,22 +710,56 @@ class Llama:
                 # We want to avoid yielding any characters from
                 # the generated text if they are part of a stop
                 # sequence.
-                longest = 0
+                first_stop_position = 0
                 for s in stop_sequences:
                     for i in range(len(s), 0, -1):
                         if all_text.endswith(s[:i]):
-                            if i > longest:
-                                longest = i
+                            if i > first_stop_position:
+                                first_stop_position = i
                             break
-                offset = 0
+                token_end_position = 0
                 remaining_tokens = completion_tokens[returned_tokens:]
                 remaining_length = len(self.detokenize(remaining_tokens))
                 for token in remaining_tokens:
-                    offset += len(self.detokenize([token]))
-                    # Check if stop sequence is not in the token
-                    if offset >= (remaining_length - longest - 1):
+                    token_end_position += len(self.detokenize([token]))
+                    # Check if stop sequence is in the token
+                    if token_end_position >= (remaining_length - first_stop_position - 1):
                         break
+                    logprobs_or_none: Optional[CompletionLogprobs] = None
+                    if logprobs is not None:
+                        token_str = self.detokenize([token]).decode(
+                            "utf-8", errors="ignore"
+                        )
+                        text_offset = len(prompt) + len(
+                            self.detokenize(completion_tokens[:returned_tokens])
+                        )
+                        token_offset = len(prompt_tokens) + returned_tokens
+                        logits = self.eval_logits[token_offset - 1]
+                        current_logprobs = Llama.logits_to_logprobs(logits)
+                        sorted_logprobs = list(
+                            sorted(
+                                zip(current_logprobs, range(len(current_logprobs))),
+                                reverse=True,
+                            )
+                        )
+                        top_logprob = {
+                            self.detokenize([llama_cpp.llama_token(i)]).decode(
+                                "utf-8", errors="ignore"
+                            ): logprob
+                            for logprob, i in sorted_logprobs[:logprobs]
+                        }
+                        top_logprob.update({token_str: current_logprobs[int(token)]})
+                        logprobs_or_none = {
+                            "tokens": [
+                                self.detokenize([token]).decode(
+                                    "utf-8", errors="ignore"
+                                )
+                            ],
+                            "text_offset": [text_offset],
+                            "token_logprobs": [sorted_logprobs[int(token)][0]],
+                            "top_logprobs": [top_logprob],
+                        }
                     returned_tokens += 1
                     yield {
                         "id": completion_id,
@@ -738,7 +772,7 @@ class Llama:
                                     "utf-8", errors="ignore"
                                 ),
                                 "index": 0,
-                                "logprobs": None,
+                                "logprobs": logprobs_or_none,
                                 "finish_reason": None,
                             }
                         ],
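
The streaming path above converts raw logits into log-probabilities with Llama.logits_to_logprobs before sorting them. Conceptually this is a log-softmax over the vocabulary; a minimal sketch of that conversion (not the library's exact implementation):

import math
from typing import List

def logits_to_logprobs(logits: List[float]) -> List[float]:
    # Numerically stable log-softmax: shift by the max logit,
    # then subtract the log of the normalizing constant.
    max_logit = max(logits)
    shifted = [x - max_logit for x in logits]
    log_sum_exp = math.log(sum(math.exp(x) for x in shifted))
    return [x - log_sum_exp for x in shifted]

print(logits_to_logprobs([2.0, 1.0, 0.0]))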
@@ -766,13 +800,48 @@ class Llama:
             else:
                 end = len(all_text)

-            offset = 0
+            token_end_position = 0
             for token in remaining_tokens:
-                offset += len(self.detokenize([token]))
-                if offset >= end:
+                token_end_position += len(self.detokenize([token]))
+
+                logprobs_or_none: Optional[CompletionLogprobs] = None
+                if logprobs is not None:
+                    token_str = self.detokenize([token]).decode(
+                        "utf-8", errors="ignore"
+                    )
+                    text_offset = len(prompt) + len(
+                        self.detokenize(completion_tokens[:returned_tokens])
+                    )
+                    token_offset = len(prompt_tokens) + returned_tokens - 1
+                    logits = self.eval_logits[token_offset]
+                    current_logprobs = Llama.logits_to_logprobs(logits)
+                    sorted_logprobs = list(
+                        sorted(
+                            zip(current_logprobs, range(len(current_logprobs))),
+                            reverse=True,
+                        )
+                    )
+                    top_logprob = {
+                        self.detokenize([llama_cpp.llama_token(i)]).decode(
+                            "utf-8", errors="ignore"
+                        ): logprob
+                        for logprob, i in sorted_logprobs[:logprobs]
+                    }
+                    top_logprob.update({token_str: current_logprobs[int(token)]})
+                    logprobs_or_none = {
+                        "tokens": [
+                            self.detokenize([token]).decode("utf-8", errors="ignore")
+                        ],
+                        "text_offset": [text_offset],
+                        "token_logprobs": [sorted_logprobs[int(token)][0]],
+                        "top_logprobs": [top_logprob],
+                    }
+
+                if token_end_position >= end:
                     last_text = self.detokenize([token])
-                    if offset == end - 1:
+                    if token_end_position == end - 1:
                         break
+                    returned_tokens += 1
                     yield {
                         "id": completion_id,
                         "object": "text_completion",
@@ -781,10 +850,10 @@ class Llama:
                         "choices": [
                             {
                                 "text": last_text[
-                                    : len(last_text) - (offset - end)
+                                    : len(last_text) - (token_end_position - end)
                                 ].decode("utf-8", errors="ignore"),
                                 "index": 0,
-                                "logprobs": None,
+                                "logprobs": logprobs_or_none,
                                 "finish_reason": finish_reason,
                             }
                         ],
@@ -802,7 +871,7 @@ class Llama:
                             "utf-8", errors="ignore"
                         ),
                         "index": 0,
-                        "logprobs": None,
+                        "logprobs": logprobs_or_none,
                         "finish_reason": finish_reason
                         if returned_tokens == len(completion_tokens)
                         else None,
@@ -821,13 +890,19 @@ class Llama:
         logprobs_or_none: Optional[CompletionLogprobs] = None
         if logprobs is not None:
-            text_offset = 0
+            text_offset = 0 if echo else len(prompt)
+            token_offset = 0 if echo else len(prompt_tokens[1:])
             text_offsets: List[int] = []
-            token_logprobs: List[float] = []
+            token_logprobs: List[Optional[float]] = []
             tokens: List[str] = []
-            top_logprobs: List[Dict[str, float]] = []
-            all_tokens = prompt_tokens + completion_tokens
+            top_logprobs: List[Optional[Dict[str, float]]] = []
+            if echo:
+                # Remove leading BOS token
+                all_tokens = prompt_tokens[1:] + completion_tokens
+            else:
+                all_tokens = completion_tokens
             all_token_strs = [
                 self.detokenize([token]).decode("utf-8", errors="ignore")
                 for token in all_tokens
@@ -835,7 +910,7 @@ class Llama:
             all_logprobs = [
                 Llama.logits_to_logprobs(list(map(float, row)))
                 for row in self.eval_logits
-            ]
+            ][token_offset:]
             for token, token_str, logprobs_token in zip(
                 all_tokens, all_token_strs, all_logprobs
             ):
@@ -848,14 +923,20 @@ class Llama:
                     )
                 )
                 token_logprobs.append(sorted_logprobs[int(token)][0])
-                top_logprob = {
+                top_logprob: Optional[Dict[str, float]] = {
                     self.detokenize([llama_cpp.llama_token(i)]).decode(
                         "utf-8", errors="ignore"
                     ): logprob
                     for logprob, i in sorted_logprobs[:logprobs]
                 }
-                top_logprob.update({token_str: sorted_logprobs[int(token)][0]})
+                top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
+            # Weird idiosyncrasy of the OpenAI API where
+            # token_logprobs and top_logprobs are null for
+            # the first token.
+            if echo and len(all_tokens) > 0:
+                token_logprobs[0] = None
+                top_logprobs[0] = None
             logprobs_or_none = {
                 "tokens": tokens,
                 "text_offset": text_offsets,